Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions astrbot/core/agent/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,21 @@ def empty(self) -> bool:
return len(self.tools) == 0

def add_tool(self, tool: FunctionTool) -> None:
"""Add a tool to the set."""
# 检查是否已存在同名工具
"""Add a tool to the set.

If a tool with the same name already exists:
- Prefer the one that is active (active=True)
- If both have the same active state, use the new one (overwrite)
"""
for i, existing_tool in enumerate(self.tools):
if existing_tool.name == tool.name:
self.tools[i] = tool
# Use getattr with a default of True for compatibility with tools
# that may not define an `active` attribute (e.g., mocks).
existing_active = bool(getattr(existing_tool, "active", True))
new_active = bool(getattr(tool, "active", True))
# Overwrite if new tool is active, or if existing tool is not active
if new_active or not existing_active:
self.tools[i] = tool
return
self.tools.append(tool)

Expand Down
23 changes: 20 additions & 3 deletions astrbot/core/provider/func_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,30 @@ def remove_func(self, name: str) -> None:
break

def get_func(self, name) -> FuncTool | None:
for f in self.func_list:
# 优先返回已激活的工具(后加载的覆盖前面的,与 ToolSet.add_tool 保持一致)
# 使用 getattr(..., True) 与 ToolSet.add_tool 保持一致:没有 active 属性的工具视为已激活
for f in reversed(self.func_list):
if f.name == name and getattr(f, "active", True):
return f
# 退化则拿最后一个同名工具
for f in reversed(self.func_list):
if f.name == name:
return f
return None

def get_full_tool_set(self) -> ToolSet:
"""获取完整工具集"""
tool_set = ToolSet(self.func_list.copy())
"""获取完整工具集

使用 ToolSet.add_tool 进行填充。对于同名工具,去重规则为:
- 优先保留 active=True 的工具;
- 当 active 状态相同时,后加载的工具会覆盖前面的工具。

因此,后加载的 inactive 工具不会覆盖已激活的工具;
同时,MCP 工具在需要时仍可覆盖被禁用的内置工具。
"""
tool_set = ToolSet()
for tool in self.func_list:
tool_set.add_tool(tool)
return tool_set

@staticmethod
Expand Down
233 changes: 233 additions & 0 deletions tests/unit/test_tool_conflict_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""Tests for tool conflict resolution in ToolSet.add_tool and FunctionToolManager.

This module tests the fix for issue #5821: when an MCP external tool shares a name
with a disabled built-in tool, the MCP tool should not be removed as collateral damage.
"""

import pytest

from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.provider.func_tool_manager import FunctionToolManager


def make_tool(name: str, active: bool = True) -> FunctionTool:
"""Create a simple FunctionTool for testing."""
return FunctionTool(
name=name,
description=f"Test tool {name}",
parameters={"type": "object", "properties": {}},
active=active,
)


class TestToolSetAddTool:
"""Tests for ToolSet.add_tool conflict resolution."""

def test_new_tool_active_existing_inactive_overwrites(self):
"""Active tool should overwrite inactive tool with same name."""
toolset = ToolSet()
toolset.add_tool(make_tool("web_search", active=False))
toolset.add_tool(make_tool("web_search", active=True))

assert len(toolset.tools) == 1
assert toolset.tools[0].active is True

def test_new_tool_inactive_existing_active_preserves_existing(self):
"""Inactive tool should NOT overwrite active tool with same name."""
toolset = ToolSet()
toolset.add_tool(make_tool("web_search", active=True))
toolset.add_tool(make_tool("web_search", active=False))

assert len(toolset.tools) == 1
assert toolset.tools[0].active is True

def test_both_active_last_one_wins(self):
"""When both tools are active, the new one should overwrite."""
toolset = ToolSet()
first = make_tool("web_search", active=True)
second = make_tool("web_search", active=True)
second.description = "Second web search"

toolset.add_tool(first)
toolset.add_tool(second)

assert len(toolset.tools) == 1
# The second tool should be the one kept
assert toolset.tools[0] is second
assert toolset.tools[0].description == "Second web search"

def test_both_inactive_last_one_wins(self):
"""When both tools are inactive, the new one should overwrite."""
toolset = ToolSet()
toolset.add_tool(make_tool("web_search", active=False))
toolset.add_tool(make_tool("web_search", active=False))

assert len(toolset.tools) == 1

def test_different_names_both_added(self):
"""Tools with different names should both be added."""
toolset = ToolSet()
toolset.add_tool(make_tool("web_search"))
toolset.add_tool(make_tool("code_search"))

assert len(toolset.tools) == 2

def test_missing_active_attribute_defaults_to_true(self):
"""Tools without 'active' attribute should be treated as active."""
toolset = ToolSet()

# Create a mock object without 'active' attribute
class MockTool:
name = "mock_tool"
description = "Mock"
parameters = {"type": "object"}

mock_tool = MockTool()
toolset.add_tool(mock_tool) # type: ignore

# Should be added successfully
assert len(toolset.tools) == 1

# Adding another tool without active should overwrite
mock_tool2 = MockTool()
toolset.add_tool(mock_tool2) # type: ignore

assert len(toolset.tools) == 1


class TestFunctionToolManagerGetFunc:
"""Tests for FunctionToolManager.get_func with conflict resolution."""

def test_returns_last_active_tool(self):
"""Should return the last active tool when multiple have same name."""
manager = FunctionToolManager()
manager.func_list = [
make_tool("web_search", active=True),
make_tool("web_search", active=True),
]

result = manager.get_func("web_search")
assert result is not None
# Should return the last one (reversed order)
assert result is manager.func_list[1]

def test_returns_active_over_inactive(self):
"""Should prefer active tool over inactive tool with same name."""
manager = FunctionToolManager()
manager.func_list = [
make_tool("web_search", active=False),
make_tool("web_search", active=True),
]

result = manager.get_func("web_search")
assert result is not None
assert result.active is True
assert result is manager.func_list[1]

def test_inactive_cannot_override_active(self):
"""Inactive tool after active should not be returned."""
manager = FunctionToolManager()
manager.func_list = [
make_tool("web_search", active=True),
make_tool("web_search", active=False),
]

result = manager.get_func("web_search")
assert result is not None
assert result.active is True
assert result is manager.func_list[0]

def test_fallback_to_last_when_none_active(self):
"""Should return last tool with matching name when none are active."""
manager = FunctionToolManager()
manager.func_list = [
make_tool("web_search", active=False),
make_tool("web_search", active=False),
]

result = manager.get_func("web_search")
assert result is not None
# Should return the last one (reversed order in fallback)
assert result is manager.func_list[1]

def test_returns_none_when_not_found(self):
"""Should return None when tool name not found."""
manager = FunctionToolManager()
manager.func_list = [make_tool("other_tool")]

result = manager.get_func("web_search")
assert result is None


class TestFunctionToolManagerGetFullToolSet:
"""Tests for FunctionToolManager.get_full_tool_set."""

def test_deduplicates_by_name_using_add_tool(self):
"""Should deduplicate tools using add_tool logic."""
manager = FunctionToolManager()
manager.func_list = [
make_tool("web_search", active=False),
make_tool("web_search", active=True),
make_tool("code_search", active=True),
]

toolset = manager.get_full_tool_set()

# Should have 2 tools after deduplication
assert len(toolset.tools) == 2
# web_search should be active (the MCP version)
web_search = toolset.get_tool("web_search")
assert web_search is not None
assert web_search.active is True

def test_no_deepcopy_preserves_identity(self):
"""Should not deep copy tools, preserving object identity."""
manager = FunctionToolManager()
tool = make_tool("web_search")
manager.func_list = [tool]

toolset = manager.get_full_tool_set()

# Same object reference (no deepcopy)
assert toolset.tools[0] is tool

def test_mcp_tool_overrides_disabled_builtin(self):
"""
Integration test: MCP tool should override disabled built-in tool.
This is the core scenario for issue #5821.
"""
manager = FunctionToolManager()
# Simulate: built-in tool registered first (disabled)
# Then MCP tool registered (enabled)
manager.func_list = [
make_tool("web_search", active=False), # Built-in, disabled
make_tool("web_search", active=True), # MCP, enabled
]

# get_func should return the MCP tool (active one)
result = manager.get_func("web_search")
assert result is not None
assert result.active is True
assert result is manager.func_list[1]

# get_full_tool_set should also keep the MCP tool
toolset = manager.get_full_tool_set()
assert len(toolset.tools) == 1
assert toolset.tools[0].active is True

def test_disabled_mcp_cannot_override_enabled_builtin(self):
"""Disabled MCP tool should not override enabled built-in tool."""
manager = FunctionToolManager()
manager.func_list = [
make_tool("web_search", active=True), # Built-in, enabled
make_tool("web_search", active=False), # MCP, disabled
]

result = manager.get_func("web_search")
assert result is not None
assert result.active is True
assert result is manager.func_list[0]

toolset = manager.get_full_tool_set()
assert len(toolset.tools) == 1
assert toolset.tools[0].active is True