diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index c149792dcd..b75999ea65 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -1,4 +1,5 @@ import asyncio +import copy import logging import os import re @@ -6,7 +7,7 @@ from contextlib import AsyncExitStack from datetime import timedelta from pathlib import Path, PureWindowsPath -from typing import Generic +from typing import Any, Generic from tenacity import ( before_sleep_log, @@ -325,6 +326,61 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"{e!s}" +def _normalize_mcp_input_schema(schema: dict[str, Any]) -> dict[str, Any]: + """Normalize common non-standard MCP JSON Schema variants. + + Some MCP servers incorrectly mark required properties with a boolean + `required: true` on the property schema itself. Draft 2020-12 requires the + parent object to declare `required` as an array of property names instead. + We lift those booleans to the parent object so the schema remains usable + without disabling validation entirely. + """ + + def _normalize(node: Any) -> Any: + if isinstance(node, list): + return [_normalize(item) for item in node] + + if not isinstance(node, dict): + return node + + normalized = {key: _normalize(value) for key, value in node.items()} + + properties = normalized.get("properties") + if isinstance(properties, dict): + original_properties = ( + node.get("properties") + if isinstance(node.get("properties"), dict) + else {} + ) + required = normalized.get("required") + required_list = required[:] if isinstance(required, list) else [] + + for prop_name, prop_schema in properties.items(): + if not isinstance(prop_schema, dict): + continue + + original_prop_schema = original_properties.get(prop_name, {}) + prop_required = ( + original_prop_schema.get("required") + if isinstance(original_prop_schema, dict) + else None + ) + if isinstance(prop_required, bool): + if prop_schema.get("required") is prop_required: + prop_schema.pop("required", None) + if prop_required: + required_list.append(prop_name) + + if required_list: + normalized["required"] = list(dict.fromkeys(required_list)) + elif isinstance(required, list): + normalized.pop("required", None) + + return normalized + + return _normalize(copy.deepcopy(schema)) + + class MCPClient: def __init__(self) -> None: # Initialize session and client objects @@ -602,7 +658,7 @@ def __init__( super().__init__( name=mcp_tool.name, description=mcp_tool.description or "", - parameters=mcp_tool.inputSchema, + parameters=_normalize_mcp_input_schema(mcp_tool.inputSchema), ) self.mcp_tool = mcp_tool self.mcp_client = mcp_client diff --git a/tests/unit/test_mcp_client_schema.py b/tests/unit/test_mcp_client_schema.py new file mode 100644 index 0000000000..0c3d9bc6ae --- /dev/null +++ b/tests/unit/test_mcp_client_schema.py @@ -0,0 +1,117 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from astrbot.core.agent.mcp_client import MCPTool, _normalize_mcp_input_schema + + +class TestNormalizeMcpInputSchema: + def test_lifts_property_level_required_booleans_to_parent_required_array(self): + schema = { + "type": "object", + "properties": { + "stock_code": {"type": "string", "required": True}, + "market": {"type": "string", "required": False}, + }, + } + + normalized = _normalize_mcp_input_schema(schema) + + assert normalized["required"] == ["stock_code"] + assert "required" not in normalized["properties"]["stock_code"] + assert "required" not in normalized["properties"]["market"] + assert schema["properties"]["stock_code"]["required"] is True + + def test_preserves_existing_required_arrays_while_fixing_nested_objects(self): + schema = { + "type": "object", + "required": ["server"], + "properties": { + "server": { + "type": "object", + "required": ["transport"], + "properties": { + "transport": {"type": "string"}, + "stock_code": {"type": "string", "required": True}, + "market": {"type": "string", "required": False}, + }, + } + }, + } + + normalized = _normalize_mcp_input_schema(schema) + + assert normalized["required"] == ["server"] + assert normalized["properties"]["server"]["required"] == [ + "transport", + "stock_code", + ] + assert ( + "required" + not in normalized["properties"]["server"]["properties"]["stock_code"] + ) + assert ( + "required" not in normalized["properties"]["server"]["properties"]["market"] + ) + + def test_preserves_parent_required_flag_for_nested_object_properties(self): + schema = { + "type": "object", + "properties": { + "server": { + "type": "object", + "required": True, + "properties": { + "transport": {"type": "string", "required": True}, + }, + } + }, + } + + normalized = _normalize_mcp_input_schema(schema) + + assert normalized["required"] == ["server"] + assert normalized["properties"]["server"]["required"] == ["transport"] + assert ( + "required" + not in normalized["properties"]["server"]["properties"]["transport"] + ) + + def test_ignores_non_boolean_required_values_and_non_dict_properties(self): + schema = { + "type": "object", + "properties": { + "server": "invalid-property-schema", + "market": {"type": "string", "required": "yes"}, + "stock_code": {"type": "string", "required": True}, + }, + } + + normalized = _normalize_mcp_input_schema(schema) + + assert normalized["required"] == ["stock_code"] + assert normalized["properties"]["server"] == "invalid-property-schema" + assert normalized["properties"]["market"]["required"] == "yes" + assert "required" not in normalized["properties"]["stock_code"] + assert schema["properties"]["server"] == "invalid-property-schema" + assert schema["properties"]["market"]["required"] == "yes" + + +class TestMCPToolSchemaNormalization: + def test_mcp_tool_accepts_property_level_required_booleans(self): + mcp_tool = SimpleNamespace( + name="quote_lookup", + description="Lookup a quote", + inputSchema={ + "type": "object", + "properties": { + "stock_code": {"type": "string", "required": True}, + "market": {"type": "string", "required": False}, + }, + }, + ) + + tool = MCPTool(mcp_tool, MagicMock(), "gf-securities") + + assert tool.parameters["required"] == ["stock_code"] + assert "required" not in tool.parameters["properties"]["stock_code"] + assert "required" not in tool.parameters["properties"]["market"]