Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion server/routers/chat_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,16 @@ async def stream_messages():
# 构造运行时配置,如果没有thread_id则生成一个
user_id = str(current_user.id)
thread_id = config.get("thread_id")
input_context = {"user_id": user_id, "thread_id": thread_id}
user_priority_knowledge = config.get("user_priority_knowledge")
user_knowledges = config.get("user_knowledges")
user_mcp_tools = config.get("user_mcp_tools")
input_context = {
"user_id": user_id,
"thread_id": thread_id,
"user_priority_knowledge": user_priority_knowledge,
"user_knowledges": user_knowledges,
"user_mcp_tools": user_mcp_tools,
}

if not thread_id:
thread_id = str(uuid.uuid4())
Expand Down
1 change: 1 addition & 0 deletions server/routers/evaluation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# 评估基准
# ============================================================================


@evaluation.get("/databases/{db_id}/benchmarks/{benchmark_id}")
async def get_evaluation_benchmark_by_db(
db_id: str, benchmark_id: str, page: int = 1, page_size: int = 10, current_user: User = Depends(get_admin_user)
Expand Down
20 changes: 14 additions & 6 deletions server/routers/mcp_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from src.services.mcp_service import (
clear_mcp_server_tools_cache,
create_mcp_server,
get_mcp_tools_stats,
delete_mcp_server,
Expand Down Expand Up @@ -219,10 +220,12 @@ async def test_mcp_server(
):
"""测试 MCP 服务器连接"""
try:
await get_server_or_404(db, name)
server = await get_server_or_404(db, name)

try:
tools = await get_all_mcp_tools(name)
tools = await get_all_mcp_tools(
name, force_refresh=True, additional_servers={server.name: server.to_mcp_config()}
)
return {
"success": True,
"message": f"连接成功,共发现 {len(tools)} 个工具",
Expand Down Expand Up @@ -276,7 +279,7 @@ async def get_mcp_server_tools(

try:
# 获取所有工具(不过滤 disabled_tools)
tools = await get_all_mcp_tools(name)
tools = await get_all_mcp_tools(name, additional_servers={server.name: server.to_mcp_config()})
tool_list = []

for tool in tools:
Expand All @@ -291,7 +294,7 @@ async def get_mcp_server_tools(
}
# 提取参数信息
if hasattr(tool, "args_schema") and tool.args_schema:
schema = tool.args_schema.schema() if hasattr(tool.args_schema, "schema") else {}
schema = tool.args_schema
tool_info["parameters"] = schema.get("properties", {})
tool_info["required"] = schema.get("required", [])
else:
Expand Down Expand Up @@ -322,11 +325,16 @@ async def refresh_mcp_server_tools(
):
"""刷新 MCP 服务器的工具列表(清除缓存重新获取)"""
try:
await get_server_or_404(db, name)
server = await get_server_or_404(db, name)

try:
# 清除缓存,强制刷新
clear_mcp_server_tools_cache(name)

# 获取所有工具(不过滤 disabled_tools)
tools = await get_all_mcp_tools(name)
tools = await get_all_mcp_tools(
name, cache=True, force_refresh=True, additional_servers={server.name: server.to_mcp_config()}
)

# 获取统计信息
stats = get_mcp_tools_stats(name)
Expand Down
2 changes: 1 addition & 1 deletion server/routers/system_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def get_system_logs(levels: str | None = None, current_user: User = Depend
if levels:
level_filter = set(level.strip().upper() for level in levels.split(",") if level.strip())

async with aiofiles.open(LOG_FILE) as f:
async with aiofiles.open(LOG_FILE, encoding="utf-8") as f:
# 读取最后1000行
lines = []
async for line in f:
Expand Down
13 changes: 9 additions & 4 deletions src/agents/chatbot/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from langchain.agents.middleware import ModelRetryMiddleware

from src.agents.common import BaseAgent, load_chat_model
from src.agents.common.middlewares import (
inject_attachment_context,
)
from src.agents.common.tools import get_tools_from_context
from src.agents.common.middlewares import DynamicToolMiddleware, inject_attachment_context
from src.agents.common.tools import get_basic_tools, get_mcp_tools, get_tools_from_context


class ChatbotAgent(BaseAgent):
Expand All @@ -24,6 +22,12 @@ async def get_graph(self, **kwargs):
# 获取上下文配置
context = self.context_schema.from_file(module_name=self.module_name)

dynamic_tool_middleware = DynamicToolMiddleware(
base_tools=[],
basic_tools=get_basic_tools(context.tools),
mcp_tools=await get_mcp_tools(server_names=context.mcps),
)

# 使用 create_agent 创建智能体
graph = create_agent(
model=load_chat_model(context.model), # 使用 context 中的模型配置
Expand All @@ -32,6 +36,7 @@ async def get_graph(self, **kwargs):
middleware=[
inject_attachment_context, # 附件上下文注入
ModelRetryMiddleware(), # 模型重试中间件
dynamic_tool_middleware,
],
checkpointer=await self._get_checkpointer(),
)
Expand Down
13 changes: 9 additions & 4 deletions src/agents/common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from src import config as sys_config
from src.knowledge import knowledge_base
from src.services.mcp_service import get_mcp_server_names
from src.services.mcp_service import get_mcp_servers_info
from src.utils import logger

from .tools import gen_tool_info, get_buildin_tools
Expand All @@ -27,6 +27,11 @@ class BaseContext:
3. 类默认配置:最低优先级,类中定义的默认值
"""

def set(self, data: dict):
"""设置配置字段"""
for key, value in data.items():
setattr(self, key, value)

def update(self, data: dict):
"""更新配置字段"""
for key, value in data.items():
Expand Down Expand Up @@ -76,11 +81,11 @@ def update(self, data: dict):
},
)

mcps: list[str] = field(
mcps: Annotated[list[dict], {"__template_metadata__": {"kind": "mcps"}}] = field(
default_factory=list,
metadata={
"name": "MCP服务器",
"options": lambda: get_mcp_server_names(),
"options": lambda: get_mcp_servers_info(),
"description": (
"MCP服务器列表,建议使用支持 SSE 的 MCP 服务器,"
"如果需要使用 uvx 或 npx 运行的服务器,也请在项目外部启动 MCP 服务器,并在项目中配置 MCP 服务器。"
Expand All @@ -106,7 +111,7 @@ def from_file(cls, module_name: str, input_context: dict = None) -> "BaseContext
context.update(file_config)

if input_context:
context.update(input_context)
context.set(input_context)

return context

Expand Down
42 changes: 31 additions & 11 deletions src/agents/common/middlewares/dynamic_tool_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ class DynamicToolMiddleware(AgentMiddleware):
运行时只是根据配置筛选工具,不能动态添加新工具
"""

def __init__(self, base_tools: list[Any], mcp_servers: list[str] | None = None):
def __init__(
self,
base_tools: list[Any],
mcp_servers: list[str] | None = None,
basic_tools: list[Any] | None = None,
kb_tools: list[Any] | None = None,
mcp_tools: list[Any] | None = None,
):
"""初始化中间件

Args:
Expand All @@ -26,6 +33,10 @@ def __init__(self, base_tools: list[Any], mcp_servers: list[str] | None = None):
self._all_mcp_tools: dict[str, list[Any]] = {} # 所有已加载的 MCP 工具
self._mcp_servers = mcp_servers or []

self.basic_tools: list[Any] = basic_tools or [] # 基础工具
self.kb_tools: list[Any] = kb_tools or [] # 基于知识库的工具
self.mcp_tools: list[Any] = mcp_tools or [] # 基于MCP的工具

async def initialize_mcp_tools(self) -> None:
"""异步初始化:预加载所有可能用到的 MCP 工具"""
for mcp_name in self._mcp_servers:
Expand All @@ -42,26 +53,35 @@ async def awrap_model_call(
) -> ModelResponse:
"""根据配置动态选择工具(从已注册的工具中筛选)"""
# 从 runtime context 获取配置
selected_tools = request.runtime.context.tools
selected_mcps = request.runtime.context.mcps
selected_tools = getattr(request.runtime.context, "tools", [])
selected_mcp_tools = getattr(request.runtime.context, "user_mcp_tools", [])
selected_knowledges = getattr(request.runtime.context, "user_knowledges", [])

enabled_tools = []

# 根据配置筛选基础工具
if selected_tools and isinstance(selected_tools, list) and len(selected_tools) > 0:
enabled_tools = [tool for tool in self.tools if tool.name in selected_tools]
enabled_tools = [tool for tool in self.basic_tools if tool.name in selected_tools]

# 根据配置筛选 MCP 工具(从已注册的工具中选择)
if selected_mcps and isinstance(selected_mcps, list) and len(selected_mcps) > 0:
for mcp in selected_mcps:
if mcp in self._all_mcp_tools:
enabled_tools.extend(self._all_mcp_tools[mcp])
else:
logger.warning(f"MCP server '{mcp}' not pre-loaded. Please add it to mcp_servers list.")
if selected_mcp_tools and isinstance(selected_mcp_tools, list) and len(selected_mcp_tools) > 0:
mcp_tools_from_selected = [tool for tool in self.mcp_tools if tool.metadata["id"] in selected_mcp_tools]
enabled_tools.extend(mcp_tools_from_selected)
else:
enabled_tools.extend(self.mcp_tools)

# 筛选知识工具
if selected_knowledges and isinstance(selected_knowledges, list) and len(selected_knowledges) > 0:
knowledge_tools_from_selected = [tool for tool in self.kb_tools if tool.name in selected_knowledges]
enabled_tools.extend(knowledge_tools_from_selected)
else:
enabled_tools.extend(self.kb_tools)

logger.info(
f"Dynamic tool selection: {len(enabled_tools)} tools enabled: {[tool.name for tool in enabled_tools]}, "
f"selected_tools: {selected_tools}, selected_mcps: {selected_mcps}"
f"selected_tools: {selected_tools}, "
f"selected_mcp_tools: {selected_mcp_tools}, "
f"selected_knowledges: {selected_knowledges}"
)

# 更新 request 中的工具列表
Expand Down
38 changes: 29 additions & 9 deletions src/agents/common/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,18 +360,39 @@ def get_buildin_tools() -> list:
return static_tools


def get_basic_tools(tools: list[str] = None, extra_tools=None):
all_basic_tools = get_buildin_tools() + (extra_tools or [])
results = []

if tools:
# 创建工具映射表
tools_map = {t.name: t for t in all_basic_tools}
for tool_name in tools:
if tool_name in tools_map:
results.append(tools_map[tool_name])

return results


async def get_mcp_tools(server_names: list[str] = None):
results = []

if server_names:
for server_name in server_names:
mcp_tools = await get_enabled_mcp_tools(server_name)
results.extend(mcp_tools)

return results


async def get_tools_from_context(context, extra_tools=None) -> list:
"""从上下文配置中获取工具列表"""
# 1. 基础工具 (从 context.tools 中筛选)
all_basic_tools = get_buildin_tools() + (extra_tools or [])
selected_tools = []

if context.tools:
# 创建工具映射表
tools_map = {t.name: t for t in all_basic_tools}
for tool_name in context.tools:
if tool_name in tools_map:
selected_tools.append(tools_map[tool_name])
basic_tools = get_basic_tools(context.tools, extra_tools)
selected_tools.extend(basic_tools)

# 2. 知识库工具
if context.knowledges:
Expand All @@ -380,8 +401,7 @@ async def get_tools_from_context(context, extra_tools=None) -> list:

# 3. MCP 工具(使用统一入口,自动过滤 disabled_tools)
if context.mcps:
for server_name in context.mcps:
mcp_tools = await get_enabled_mcp_tools(server_name)
selected_tools.extend(mcp_tools)
mcp_tools = await get_mcp_tools(server_names=context.mcps)
selected_tools.extend(mcp_tools)

return selected_tools
4 changes: 1 addition & 3 deletions src/agents/reporter/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from langchain.agents import create_agent

from src.agents.common import BaseAgent, BaseContext, load_chat_model
from src.agents.common.tools import gen_tool_info, get_buildin_tools
from src.agents.common.toolkits.mysql import get_mysql_tools
from src.agents.common.tools import get_tools_from_context
from src.agents.common.tools import gen_tool_info, get_buildin_tools, get_tools_from_context
from src.utils import logger


@dataclass(kw_only=True)
class ReporterContext(BaseContext):

# 覆盖默认的工具列表,添加 MySQL 工具包
tools: Annotated[list[dict], {"__template_metadata__": {"kind": "tools"}}] = field(
default_factory=lambda: [t.name for t in get_mysql_tools()],
Expand Down
Loading