diff --git a/server/routers/chat_router.py b/server/routers/chat_router.py index 030ede903..0829586ad 100644 --- a/server/routers/chat_router.py +++ b/server/routers/chat_router.py @@ -542,7 +542,16 @@ async def stream_messages(): # 构造运行时配置,如果没有thread_id则生成一个 user_id = str(current_user.id) thread_id = config.get("thread_id") - input_context = {"user_id": user_id, "thread_id": thread_id} + user_priority_knowledge = config.get("user_priority_knowledge") + user_knowledges = config.get("user_knowledges") + user_mcp_tools = config.get("user_mcp_tools") + input_context = { + "user_id": user_id, + "thread_id": thread_id, + "user_priority_knowledge": user_priority_knowledge, + "user_knowledges": user_knowledges, + "user_mcp_tools": user_mcp_tools, + } if not thread_id: thread_id = str(uuid.uuid4()) diff --git a/server/routers/evaluation_router.py b/server/routers/evaluation_router.py index 2c12575b1..1db786c6f 100644 --- a/server/routers/evaluation_router.py +++ b/server/routers/evaluation_router.py @@ -14,6 +14,7 @@ # 评估基准 # ============================================================================ + @evaluation.get("/databases/{db_id}/benchmarks/{benchmark_id}") async def get_evaluation_benchmark_by_db( db_id: str, benchmark_id: str, page: int = 1, page_size: int = 10, current_user: User = Depends(get_admin_user) diff --git a/server/routers/mcp_router.py b/server/routers/mcp_router.py index d886ec019..f773e5802 100644 --- a/server/routers/mcp_router.py +++ b/server/routers/mcp_router.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from src.services.mcp_service import ( + clear_mcp_server_tools_cache, create_mcp_server, get_mcp_tools_stats, delete_mcp_server, @@ -219,10 +220,12 @@ async def test_mcp_server( ): """测试 MCP 服务器连接""" try: - await get_server_or_404(db, name) + server = await get_server_or_404(db, name) try: - tools = await get_all_mcp_tools(name) + tools = await get_all_mcp_tools( + name, force_refresh=True, additional_servers={server.name: server.to_mcp_config()} + ) return { "success": True, "message": f"连接成功,共发现 {len(tools)} 个工具", @@ -276,7 +279,7 @@ async def get_mcp_server_tools( try: # 获取所有工具(不过滤 disabled_tools) - tools = await get_all_mcp_tools(name) + tools = await get_all_mcp_tools(name, additional_servers={server.name: server.to_mcp_config()}) tool_list = [] for tool in tools: @@ -291,7 +294,7 @@ async def get_mcp_server_tools( } # 提取参数信息 if hasattr(tool, "args_schema") and tool.args_schema: - schema = tool.args_schema.schema() if hasattr(tool.args_schema, "schema") else {} + schema = tool.args_schema tool_info["parameters"] = schema.get("properties", {}) tool_info["required"] = schema.get("required", []) else: @@ -322,11 +325,16 @@ async def refresh_mcp_server_tools( ): """刷新 MCP 服务器的工具列表(清除缓存重新获取)""" try: - await get_server_or_404(db, name) + server = await get_server_or_404(db, name) try: + # 清除缓存,强制刷新 + clear_mcp_server_tools_cache(name) + # 获取所有工具(不过滤 disabled_tools) - tools = await get_all_mcp_tools(name) + tools = await get_all_mcp_tools( + name, cache=True, force_refresh=True, additional_servers={server.name: server.to_mcp_config()} + ) # 获取统计信息 stats = get_mcp_tools_stats(name) diff --git a/server/routers/system_router.py b/server/routers/system_router.py index 7be72c3d3..7dabc111d 100644 --- a/server/routers/system_router.py +++ b/server/routers/system_router.py @@ -66,7 +66,7 @@ async def get_system_logs(levels: str | None = None, current_user: User = Depend if levels: level_filter = set(level.strip().upper() for level in levels.split(",") if level.strip()) - async with aiofiles.open(LOG_FILE) as f: + async with aiofiles.open(LOG_FILE, encoding="utf-8") as f: # 读取最后1000行 lines = [] async for line in f: diff --git a/src/agents/chatbot/graph.py b/src/agents/chatbot/graph.py index 810bdb1a8..5425a7d7c 100644 --- a/src/agents/chatbot/graph.py +++ b/src/agents/chatbot/graph.py @@ -2,10 +2,8 @@ from langchain.agents.middleware import ModelRetryMiddleware from src.agents.common import BaseAgent, load_chat_model -from src.agents.common.middlewares import ( - inject_attachment_context, -) -from src.agents.common.tools import get_tools_from_context +from src.agents.common.middlewares import DynamicToolMiddleware, inject_attachment_context +from src.agents.common.tools import get_basic_tools, get_mcp_tools, get_tools_from_context class ChatbotAgent(BaseAgent): @@ -24,6 +22,12 @@ async def get_graph(self, **kwargs): # 获取上下文配置 context = self.context_schema.from_file(module_name=self.module_name) + dynamic_tool_middleware = DynamicToolMiddleware( + base_tools=[], + basic_tools=get_basic_tools(context.tools), + mcp_tools=await get_mcp_tools(server_names=context.mcps), + ) + # 使用 create_agent 创建智能体 graph = create_agent( model=load_chat_model(context.model), # 使用 context 中的模型配置 @@ -32,6 +36,7 @@ async def get_graph(self, **kwargs): middleware=[ inject_attachment_context, # 附件上下文注入 ModelRetryMiddleware(), # 模型重试中间件 + dynamic_tool_middleware, ], checkpointer=await self._get_checkpointer(), ) diff --git a/src/agents/common/context.py b/src/agents/common/context.py index 190fb7b2e..521f06856 100644 --- a/src/agents/common/context.py +++ b/src/agents/common/context.py @@ -10,7 +10,7 @@ from src import config as sys_config from src.knowledge import knowledge_base -from src.services.mcp_service import get_mcp_server_names +from src.services.mcp_service import get_mcp_servers_info from src.utils import logger from .tools import gen_tool_info, get_buildin_tools @@ -27,6 +27,11 @@ class BaseContext: 3. 类默认配置:最低优先级,类中定义的默认值 """ + def set(self, data: dict): + """设置配置字段""" + for key, value in data.items(): + setattr(self, key, value) + def update(self, data: dict): """更新配置字段""" for key, value in data.items(): @@ -76,11 +81,11 @@ def update(self, data: dict): }, ) - mcps: list[str] = field( + mcps: Annotated[list[dict], {"__template_metadata__": {"kind": "mcps"}}] = field( default_factory=list, metadata={ "name": "MCP服务器", - "options": lambda: get_mcp_server_names(), + "options": lambda: get_mcp_servers_info(), "description": ( "MCP服务器列表,建议使用支持 SSE 的 MCP 服务器," "如果需要使用 uvx 或 npx 运行的服务器,也请在项目外部启动 MCP 服务器,并在项目中配置 MCP 服务器。" @@ -106,7 +111,7 @@ def from_file(cls, module_name: str, input_context: dict = None) -> "BaseContext context.update(file_config) if input_context: - context.update(input_context) + context.set(input_context) return context diff --git a/src/agents/common/middlewares/dynamic_tool_middleware.py b/src/agents/common/middlewares/dynamic_tool_middleware.py index 6f11e4c9f..1930e495f 100644 --- a/src/agents/common/middlewares/dynamic_tool_middleware.py +++ b/src/agents/common/middlewares/dynamic_tool_middleware.py @@ -14,7 +14,14 @@ class DynamicToolMiddleware(AgentMiddleware): 运行时只是根据配置筛选工具,不能动态添加新工具 """ - def __init__(self, base_tools: list[Any], mcp_servers: list[str] | None = None): + def __init__( + self, + base_tools: list[Any], + mcp_servers: list[str] | None = None, + basic_tools: list[Any] | None = None, + kb_tools: list[Any] | None = None, + mcp_tools: list[Any] | None = None, + ): """初始化中间件 Args: @@ -26,6 +33,10 @@ def __init__(self, base_tools: list[Any], mcp_servers: list[str] | None = None): self._all_mcp_tools: dict[str, list[Any]] = {} # 所有已加载的 MCP 工具 self._mcp_servers = mcp_servers or [] + self.basic_tools: list[Any] = basic_tools or [] # 基础工具 + self.kb_tools: list[Any] = kb_tools or [] # 基于知识库的工具 + self.mcp_tools: list[Any] = mcp_tools or [] # 基于MCP的工具 + async def initialize_mcp_tools(self) -> None: """异步初始化:预加载所有可能用到的 MCP 工具""" for mcp_name in self._mcp_servers: @@ -42,26 +53,35 @@ async def awrap_model_call( ) -> ModelResponse: """根据配置动态选择工具(从已注册的工具中筛选)""" # 从 runtime context 获取配置 - selected_tools = request.runtime.context.tools - selected_mcps = request.runtime.context.mcps + selected_tools = getattr(request.runtime.context, "tools", []) + selected_mcp_tools = getattr(request.runtime.context, "user_mcp_tools", []) + selected_knowledges = getattr(request.runtime.context, "user_knowledges", []) enabled_tools = [] # 根据配置筛选基础工具 if selected_tools and isinstance(selected_tools, list) and len(selected_tools) > 0: - enabled_tools = [tool for tool in self.tools if tool.name in selected_tools] + enabled_tools = [tool for tool in self.basic_tools if tool.name in selected_tools] # 根据配置筛选 MCP 工具(从已注册的工具中选择) - if selected_mcps and isinstance(selected_mcps, list) and len(selected_mcps) > 0: - for mcp in selected_mcps: - if mcp in self._all_mcp_tools: - enabled_tools.extend(self._all_mcp_tools[mcp]) - else: - logger.warning(f"MCP server '{mcp}' not pre-loaded. Please add it to mcp_servers list.") + if selected_mcp_tools and isinstance(selected_mcp_tools, list) and len(selected_mcp_tools) > 0: + mcp_tools_from_selected = [tool for tool in self.mcp_tools if tool.metadata["id"] in selected_mcp_tools] + enabled_tools.extend(mcp_tools_from_selected) + else: + enabled_tools.extend(self.mcp_tools) + + # 筛选知识工具 + if selected_knowledges and isinstance(selected_knowledges, list) and len(selected_knowledges) > 0: + knowledge_tools_from_selected = [tool for tool in self.kb_tools if tool.name in selected_knowledges] + enabled_tools.extend(knowledge_tools_from_selected) + else: + enabled_tools.extend(self.kb_tools) logger.info( f"Dynamic tool selection: {len(enabled_tools)} tools enabled: {[tool.name for tool in enabled_tools]}, " - f"selected_tools: {selected_tools}, selected_mcps: {selected_mcps}" + f"selected_tools: {selected_tools}, " + f"selected_mcp_tools: {selected_mcp_tools}, " + f"selected_knowledges: {selected_knowledges}" ) # 更新 request 中的工具列表 diff --git a/src/agents/common/tools.py b/src/agents/common/tools.py index d4dc0791d..94359a343 100644 --- a/src/agents/common/tools.py +++ b/src/agents/common/tools.py @@ -360,18 +360,39 @@ def get_buildin_tools() -> list: return static_tools +def get_basic_tools(tools: list[str] = None, extra_tools=None): + all_basic_tools = get_buildin_tools() + (extra_tools or []) + results = [] + + if tools: + # 创建工具映射表 + tools_map = {t.name: t for t in all_basic_tools} + for tool_name in tools: + if tool_name in tools_map: + results.append(tools_map[tool_name]) + + return results + + +async def get_mcp_tools(server_names: list[str] = None): + results = [] + + if server_names: + for server_name in server_names: + mcp_tools = await get_enabled_mcp_tools(server_name) + results.extend(mcp_tools) + + return results + + async def get_tools_from_context(context, extra_tools=None) -> list: """从上下文配置中获取工具列表""" # 1. 基础工具 (从 context.tools 中筛选) - all_basic_tools = get_buildin_tools() + (extra_tools or []) selected_tools = [] if context.tools: - # 创建工具映射表 - tools_map = {t.name: t for t in all_basic_tools} - for tool_name in context.tools: - if tool_name in tools_map: - selected_tools.append(tools_map[tool_name]) + basic_tools = get_basic_tools(context.tools, extra_tools) + selected_tools.extend(basic_tools) # 2. 知识库工具 if context.knowledges: @@ -380,8 +401,7 @@ async def get_tools_from_context(context, extra_tools=None) -> list: # 3. MCP 工具(使用统一入口,自动过滤 disabled_tools) if context.mcps: - for server_name in context.mcps: - mcp_tools = await get_enabled_mcp_tools(server_name) - selected_tools.extend(mcp_tools) + mcp_tools = await get_mcp_tools(server_names=context.mcps) + selected_tools.extend(mcp_tools) return selected_tools diff --git a/src/agents/reporter/graph.py b/src/agents/reporter/graph.py index d104b6813..15cbcdd2f 100644 --- a/src/agents/reporter/graph.py +++ b/src/agents/reporter/graph.py @@ -4,15 +4,13 @@ from langchain.agents import create_agent from src.agents.common import BaseAgent, BaseContext, load_chat_model -from src.agents.common.tools import gen_tool_info, get_buildin_tools from src.agents.common.toolkits.mysql import get_mysql_tools -from src.agents.common.tools import get_tools_from_context +from src.agents.common.tools import gen_tool_info, get_buildin_tools, get_tools_from_context from src.utils import logger @dataclass(kw_only=True) class ReporterContext(BaseContext): - # 覆盖默认的工具列表,添加 MySQL 工具包 tools: Annotated[list[dict], {"__template_metadata__": {"kind": "tools"}}] = field( default_factory=lambda: [t.name for t in get_mysql_tools()], diff --git a/src/services/mcp_service.py b/src/services/mcp_service.py index 90651a537..6b31cbe45 100644 --- a/src/services/mcp_service.py +++ b/src/services/mcp_service.py @@ -100,7 +100,7 @@ async def sync_mcp_server_to_cache(name: str, config: dict[str, Any] | None) -> logger.info(f"Synced MCP server '{name}' to cache") # Clear tools cache for this server - _mcp_tools_cache.pop(name, None) + # _mcp_tools_cache.pop(name, None) async def init_mcp_servers() -> None: @@ -201,6 +201,68 @@ def to_camel_case(s: str) -> str: return s +async def _fetch_tools_from_server(server_name: str, server_config: dict[str, Any]) -> list[Callable[..., Any]]: + """Fetch tools from MCP server and process them. + + This function handles: + 1. Connecting to the MCP server + 2. Fetching all tools + 3. Processing tool IDs and metadata + + Args: + server_name: Server name + server_config: Server configuration (without non-connection fields) + + Returns: + List of processed tools with unique IDs + """ + try: + # Extract connection config + client_config = { + k: v + for k, v in server_config.items() + if k + not in ( + "disabled_tools", + "description", + "icon", + "enabled", + ) + } + + client = await get_mcp_client({server_name: client_config}) + if client is None: + return [] + + # Get ALL tools (Raw) + raw_tools = cast(list[Any], await client.get_tools()) + + # Render IDs for ALL tools + server_cc = to_camel_case(server_name) + all_processed_tools = [] + for tool in raw_tools: + # Render unique ID rule: mcp__[camelCaseServer]__[camelCaseTool] + original_name = tool.name + tool_cc = to_camel_case(original_name) + unique_id = f"mcp__{server_cc}__{tool_cc}" + + # Use metadata to store + if tool.metadata is None: + tool.metadata = {} + tool.metadata["id"] = unique_id + + all_processed_tools.append(tool) + + return all_processed_tools + + except AssertionError as e: + logger.warning(f"[assert] Failed to load tools from MCP server '{server_name}': {e}") + return [] + except Exception as e: + logger.error(f"Failed to load tools from MCP server '{server_name}': {e}, traceback: {traceback.format_exc()}") + return [] + + async def get_mcp_tools( server_name: str, additional_servers: dict[str, dict] = None, @@ -236,61 +298,31 @@ async def get_mcp_tools( all_processed_tools = _mcp_tools_cache[server_name] else: # Need to fetch from server - try: - assert server_name in mcp_servers, f"Server {server_name} not found in ({list(mcp_servers.keys())})" + assert server_name in mcp_servers, f"Server {server_name} not found in ({list(mcp_servers.keys())})" - # Extract connection config - server_config = mcp_servers[server_name] - client_config = {k: v for k, v in server_config.items() if k not in ("disabled_tools",)} + # Extract connection config + server_config = mcp_servers[server_name] - client = await get_mcp_client({server_name: client_config}) - if client is None: - return [] + # Use the new helper function to fetch and process tools + all_processed_tools = await _fetch_tools_from_server(server_name, server_config) - # Get ALL tools (Raw) - raw_tools = cast(list[Any], await client.get_tools()) - - # Render IDs for ALL tools - server_cc = to_camel_case(server_name) - for tool in raw_tools: - # Render unique ID rule: mcp__[camelCaseServer]__[camelCaseTool] - original_name = tool.name - tool_cc = to_camel_case(original_name) - unique_id = f"mcp__{server_cc}__{tool_cc}" - - # Use metadata to store - if tool.metadata is None: - tool.metadata = {} - tool.metadata["id"] = unique_id - - all_processed_tools.append(tool) - - # Update Cache (Store the FULL list) - if cache: - _mcp_tools_cache[server_name] = all_processed_tools - - # Update Stats - # Stats should reflect the GLOBAL configuration state - # (How many are disabled in the stored config, not the transient arg) - global_config_disabled = mcp_servers.get(server_name, {}).get("disabled_tools") or [] - enabled_count = len([t for t in all_processed_tools if t.name not in global_config_disabled]) - - _mcp_tools_stats[server_name] = { - "total": len(all_processed_tools), - "enabled": enabled_count, - "disabled": len(all_processed_tools) - enabled_count, - } + # Update Cache (Store the FULL list) + if cache: + _mcp_tools_cache[server_name] = all_processed_tools - logger.info(f"Refreshed MCP tools cache for '{server_name}': {len(all_processed_tools)} tools loaded.") + # Update Stats + # Stats should reflect the GLOBAL configuration state + # (How many are disabled in the stored config, not the transient arg) + global_config_disabled = mcp_servers.get(server_name, {}).get("disabled_tools") or [] + enabled_count = len([t for t in all_processed_tools if t.name not in global_config_disabled]) - except AssertionError as e: - logger.warning(f"[assert] Failed to load tools from MCP server '{server_name}': {e}") - return [] - except Exception as e: - logger.error( - f"Failed to load tools from MCP server '{server_name}': {e}, traceback: {traceback.format_exc()}" - ) - return [] + _mcp_tools_stats[server_name] = { + "total": len(all_processed_tools), + "enabled": enabled_count, + "disabled": len(all_processed_tools) - enabled_count, + } + + logger.info(f"Refreshed MCP tools cache for '{server_name}': {len(all_processed_tools)} tools loaded.") # 3. Filtering (Apply to Return Value Only) if disabled_tools: @@ -540,7 +572,7 @@ async def toggle_tool_enabled( await db.commit() # Clear tool cache (re-filtered on next fetch) - clear_mcp_server_tools_cache(server_name) + # clear_mcp_server_tools_cache(server_name) logger.info(f"Toggled tool '{tool_name}' for server '{server_name}' enabled={enabled}") return enabled, server @@ -559,6 +591,83 @@ def get_mcp_server_names() -> list[str]: return list(MCP_SERVERS.keys()) +def get_mcp_servers_info(include_disabled: bool = False) -> list[dict[str, Any]]: + """Get MCP servers information with tools list. + + Args: + refresh_tools: Whether to force refresh tools from servers + include_disabled: Whether to include disabled servers in the result + + Returns: + List of server info dictionaries containing name, description, tools, etc. + """ + result = [] + + for server_name, config in MCP_SERVERS.items(): + # Skip disabled servers if requested + if not include_disabled and not config.get("enabled", True): + continue + + # Get server basic info + server_info = { + "id": server_name, + "name": server_name, + "description": config.get("description", ""), + "icon": config.get("icon", ""), + "enabled": config.get("enabled", True), + "transport": config.get("transport", ""), + "tools": [], + } + + def _format_parameters(input_schema: dict) -> list: + """将JSON Schema格式的参数转换为前端可直接使用的对象数组""" + if not input_schema or not input_schema.get("properties"): + return [] + required_params = input_schema.get("required", []) + return [ + { + "name": key, + "type": value.get("type", "unknown"), + "description": value.get("description", ""), + "required": key in required_params, + } + for key, value in input_schema["properties"].items() + ] + + # Get tools for this server + try: + # Get enabled tools (filter disabled_tools) + disabled_tools = config.get("disabled_tools") or [] + tools = _mcp_tools_cache.get(server_name, []) + enabled_tools = [t for t in tools if t.name not in disabled_tools] + # 转换为响应模型,包含参数信息 + tool_responses = [] + for tool in enabled_tools: + # 获取参数信息 + args_schema = getattr(tool, "args_schema", {}) or {} + parameters = _format_parameters(args_schema) + + tool_id = getattr(tool, "metadata", {}).get("id", "unknown") + tool_name = getattr(tool, "name", "unknown") + tool_description = getattr(tool, "description", "") + + tool_responses.append( + {"id": tool_id, "name": tool_name, "description": tool_description, "parameters": parameters} + ) + + # Filter out disabled tools + server_info["tools"] = tool_responses + + except Exception as e: + logger.warning(f"Failed to get tools for server '{server_name}': {e}") + # Keep server info but with empty tools list + server_info["tools"] = [] + + result.append(server_info) + + return list(result) + + async def get_enabled_mcp_tools(server_name: str) -> list: """Get MCP server tools (auto-filtering disabled_tools). @@ -594,7 +703,9 @@ async def get_servers_config(names: list[str]) -> dict[str, dict[str, Any]]: return {name: MCP_SERVERS[name] for name in names if name in MCP_SERVERS} -async def get_all_mcp_tools(server_name: str) -> list: +async def get_all_mcp_tools( + server_name: str, cache: bool = True, force_refresh: bool = False, additional_servers=None +) -> list: """Get all tools of an MCP server (no filtering). For management UI to display tool list, supports viewing all tools and their enabled status. @@ -602,6 +713,9 @@ async def get_all_mcp_tools(server_name: str) -> list: Args: server_name: Server name + cache: Whether to use/update the cache (default: True) + force_refresh: Whether to force a refresh from the server (default: False) + additional_servers: Additional server configurations Returns: List of all tools (unfiltered) @@ -609,7 +723,10 @@ async def get_all_mcp_tools(server_name: str) -> list: config = MCP_SERVERS.get(server_name) if not config: logger.warning(f"MCP server '{server_name}' not found in cache") - return [] + if not additional_servers: + return [] # Get all tools (no filtering, force refresh, no cache update) - return await get_mcp_tools(server_name, disabled_tools=[], cache=False, force_refresh=True) + return await get_mcp_tools( + server_name, additional_servers, disabled_tools=[], cache=cache, force_refresh=force_refresh + ) diff --git a/src/storage/db/models.py b/src/storage/db/models.py index 6606410a6..cc9790666 100644 --- a/src/storage/db/models.py +++ b/src/storage/db/models.py @@ -383,4 +383,11 @@ def to_mcp_config(self) -> dict: config["sse_read_timeout"] = self.sse_read_timeout if self.disabled_tools: config["disabled_tools"] = self.disabled_tools + + if self.description: + config["description"] = self.description + if self.icon: + config["icon"] = self.icon + if self.enabled: + config["enabled"] = self.enabled return config diff --git a/web/src/components/AgentChatComponent.vue b/web/src/components/AgentChatComponent.vue index be09856e0..ad41a8046 100644 --- a/web/src/components/AgentChatComponent.vue +++ b/web/src/components/AgentChatComponent.vue @@ -6,6 +6,7 @@ :is-sidebar-open="chatUIStore.isSidebarOpen" :is-initial-render="localUIState.isInitialRender" :single-mode="props.singleMode" + :is-loading="isProcessing" :agents="agents" :selected-agent-id="currentAgentId" :is-creating-new-chat="chatUIStore.creatingNewChat" @@ -79,10 +80,6 @@ 正在加载消息... -