From eb65e73c70353441c6a3c741e603770e71df4bf6 Mon Sep 17 00:00:00 2001 From: wjiajian <71640797+wjiajian@users.noreply.github.com> Date: Fri, 24 Apr 2026 14:50:48 +0800 Subject: [PATCH 1/7] feat: add Firecrawl web search and extract tools, update configuration and tests --- astrbot/core/astr_main_agent.py | 5 + astrbot/core/config/default.py | 11 + astrbot/core/tools/web_search_tools.py | 200 ++++++++++++++++++ .../components/chat/MessageListDEPRECATED.vue | 2 +- .../en-US/features/config-metadata.json | 4 + .../ru-RU/features/config-metadata.json | 4 + .../zh-CN/features/config-metadata.json | 4 + tests/unit/test_astr_main_agent.py | 31 +++ tests/unit/test_func_tool_manager.py | 14 ++ tests/unit/test_web_search_tools.py | 95 +++++++++ 10 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_web_search_tools.py diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index a9ddb2e7b9..87cb2db064 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -77,6 +77,8 @@ BaiduWebSearchTool, BochaWebSearchTool, BraveWebSearchTool, + FirecrawlExtractWebPageTool, + FirecrawlWebSearchTool, TavilyExtractWebPageTool, TavilyWebSearchTool, normalize_legacy_web_search_config, @@ -1047,6 +1049,9 @@ async def _apply_web_search_tools( req.func_tool.add_tool(tool_mgr.get_builtin_tool(BochaWebSearchTool)) elif provider == "brave": req.func_tool.add_tool(tool_mgr.get_builtin_tool(BraveWebSearchTool)) + elif provider == "firecrawl": + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlWebSearchTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlExtractWebPageTool)) elif provider == "baidu_ai_search": req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool)) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 903f6c445f..cd1c81a888 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -3202,6 +3202,7 @@ class ChatProviderTemplate(TypedDict): "baidu_ai_search", "bocha", "brave", + "firecrawl", ], "condition": { "provider_settings.web_search": True, @@ -3237,6 +3238,16 @@ class ChatProviderTemplate(TypedDict): "provider_settings.web_search": True, }, }, + "provider_settings.websearch_firecrawl_key": { + "description": "Firecrawl API Key", + "type": "list", + "items": {"type": "string"}, + "hint": "可添加多个 Key 进行轮询。", + "condition": { + "provider_settings.websearch_provider": "firecrawl", + "provider_settings.web_search": True, + }, + }, "provider_settings.websearch_baidu_app_builder_key": { "description": "百度千帆智能云 APP Builder API Key", "type": "string", diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index ca89bc17d4..a1456f4c8f 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -19,6 +19,8 @@ "tavily_extract_web_page", "web_search_bocha", "web_search_brave", + "web_search_firecrawl", + "firecrawl_extract_web_page", ] _TAVILY_WEB_SEARCH_TOOL_CONFIG = { "provider_settings.web_search": True, @@ -32,6 +34,10 @@ "provider_settings.web_search": True, "provider_settings.websearch_provider": "brave", } +_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG = { + "provider_settings.web_search": True, + "provider_settings.websearch_provider": "firecrawl", +} _BAIDU_WEB_SEARCH_TOOL_CONFIG = { "provider_settings.web_search": True, "provider_settings.websearch_provider": "baidu_ai_search", @@ -69,6 +75,7 @@ async def get(self, provider_settings: dict) -> str: _TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily") _BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha") _BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave") +_FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl") def normalize_legacy_web_search_config(cfg) -> None: @@ -91,6 +98,7 @@ def normalize_legacy_web_search_config(cfg) -> None: "websearch_tavily_key", "websearch_bocha_key", "websearch_brave_key", + "websearch_firecrawl_key", ): value = provider_settings.get(setting_name) if isinstance(value, str): @@ -258,6 +266,70 @@ async def _brave_search( ] +async def _firecrawl_search( + provider_settings: dict, + payload: dict, +) -> list[SearchResult]: + firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings) + headers = { + "Authorization": f"Bearer {firecrawl_key}", + "Content-Type": "application/json", + } + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://api.firecrawl.dev/v2/search", + json=payload, + headers=headers, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Firecrawl web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + rows = data.get("data", {}).get("web", []) + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=( + item.get("description") + or item.get("snippet") + or item.get("markdown") + or "" + ), + ) + for item in rows + if item.get("url") + ] + + +async def _firecrawl_scrape(provider_settings: dict, payload: dict) -> dict: + firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings) + headers = { + "Authorization": f"Bearer {firecrawl_key}", + "Content-Type": "application/json", + } + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://api.firecrawl.dev/v2/scrape", + json=payload, + headers=headers, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Firecrawl web scrape failed: {reason}, status: {response.status}", + ) + data = await response.json() + result = data.get("data", {}) + if not result: + raise ValueError( + "Error: Firecrawl web scraper does not return any results." + ) + return result + + async def _baidu_search( provider_settings: dict, payload: dict, @@ -548,6 +620,134 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) +@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG) +@pydantic_dataclass +class FirecrawlWebSearchTool(FunctionTool[AstrAgentContext]): + name: str = "web_search_firecrawl" + description: str = ( + "A web search tool based on Firecrawl Search API, used to retrieve web " + "pages related to the user's query." + ) + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Required. Search query."}, + "limit": { + "type": "integer", + "description": "Optional. Number of results to return. Range: 1-100. Default is 5.", + }, + "tbs": { + "type": "string", + "description": 'Optional. Time-based search filter, for example "qdr:d" or "qdr:w".', + }, + "location": { + "type": "string", + "description": "Optional. Geographic location for search results.", + }, + "country": { + "type": "string", + "description": 'Optional. Country code for search results, for example "US" or "CN".', + }, + "timeout": { + "type": "integer", + "description": "Optional. Request timeout in milliseconds.", + }, + }, + "required": ["query"], + } + ) + + async def call(self, context, **kwargs) -> ToolExecResult: + _, provider_settings, _ = _get_runtime(context) + if not provider_settings.get("websearch_firecrawl_key", []): + return "Error: Firecrawl API key is not configured in AstrBot." + + limit = int(kwargs.get("limit", 5)) + if limit < 1: + limit = 1 + if limit > 100: + limit = 100 + + payload = { + "query": kwargs["query"], + "limit": limit, + "sources": ["web"], + } + for key in ("tbs", "location", "country", "timeout"): + if kwargs.get(key): + payload[key] = kwargs[key] + + results = await _firecrawl_search(provider_settings, payload) + if not results: + return "Error: Firecrawl web searcher does not return any results." + return _search_result_payload(results) + + +@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG) +@pydantic_dataclass +class FirecrawlExtractWebPageTool(FunctionTool[AstrAgentContext]): + name: str = "firecrawl_extract_web_page" + description: str = "Extract the content of a web page using Firecrawl." + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "Required. A URL to extract content from.", + }, + "format": { + "type": "string", + "description": 'Optional. Output format, one of "markdown", "html", "rawHtml", "summary". Default is "markdown".', + }, + "only_main_content": { + "type": "boolean", + "description": "Optional. Whether to extract only the main page content. Default is true.", + }, + "timeout": { + "type": "integer", + "description": "Optional. Request timeout in milliseconds.", + }, + "max_age": { + "type": "integer", + "description": "Optional. Maximum cache age in milliseconds.", + }, + }, + "required": ["url"], + } + ) + + async def call(self, context, **kwargs) -> ToolExecResult: + _, provider_settings, _ = _get_runtime(context) + if not provider_settings.get("websearch_firecrawl_key", []): + return "Error: Firecrawl API key is not configured in AstrBot." + + url = str(kwargs.get("url", "")).strip() + if not url: + return "Error: url must be a non-empty string." + + output_format = kwargs.get("format", "markdown") + if output_format not in ["markdown", "html", "rawHtml", "summary"]: + output_format = "markdown" + + payload = { + "url": url, + "formats": [output_format], + "onlyMainContent": kwargs.get("only_main_content", True), + } + if kwargs.get("timeout"): + payload["timeout"] = kwargs["timeout"] + if kwargs.get("max_age"): + payload["maxAge"] = kwargs["max_age"] + + result = await _firecrawl_scrape(provider_settings, payload) + content = result.get(output_format, "") + result_url = result.get("url") or url + ret = f"URL: {result_url}\nContent: {content}" if content else "" + return ret or "Error: Firecrawl web scraper does not return any results." + + @builtin_tool(config=_BAIDU_WEB_SEARCH_TOOL_CONFIG) @pydantic_dataclass class BaiduWebSearchTool(FunctionTool[AstrAgentContext]): diff --git a/dashboard/src/components/chat/MessageListDEPRECATED.vue b/dashboard/src/components/chat/MessageListDEPRECATED.vue index 128d0d97b1..1271042b39 100644 --- a/dashboard/src/components/chat/MessageListDEPRECATED.vue +++ b/dashboard/src/components/chat/MessageListDEPRECATED.vue @@ -303,7 +303,7 @@ export default { part.tool_calls.forEach(toolCall => { // 检查是否是支持引用解析的 web_search 工具调用 if ( - !['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave'].includes(toolCall.name) || + !['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave', 'web_search_firecrawl'].includes(toolCall.name) || !toolCall.result ) { return; diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 373182fc15..4f35dd2859 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -125,6 +125,10 @@ "description": "Brave Search API Key", "hint": "Multiple keys can be added for rotation." }, + "websearch_firecrawl_key": { + "description": "Firecrawl API Key", + "hint": "Multiple keys can be added for rotation." + }, "websearch_baidu_app_builder_key": { "description": "Baidu Qianfan Smart Cloud APP Builder API Key", "hint": "Reference: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index c578f79d1c..08d11aed6a 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -125,6 +125,10 @@ "description": "API-ключ Brave Search", "hint": "Можно добавить несколько ключей для ротации." }, + "websearch_firecrawl_key": { + "description": "API-ключ Firecrawl", + "hint": "Можно добавить несколько ключей для ротации." + }, "websearch_baidu_app_builder_key": { "description": "API-ключ Baidu Qianfan APP Builder", "hint": "Ссылка: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index dd8711345c..8495f9ba1a 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -127,6 +127,10 @@ "description": "Brave Search API Key", "hint": "可添加多个 Key 进行轮询。" }, + "websearch_firecrawl_key": { + "description": "Firecrawl API Key", + "hint": "可添加多个 Key 进行轮询。" + }, "websearch_baidu_app_builder_key": { "description": "百度千帆智能云 APP Builder API Key", "hint": "参考:[https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 6ca1a3f2aa..5a5bceae15 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -398,6 +398,37 @@ async def test_apply_web_search_tools_uses_builtin_tool_manager( assert req.func_tool is not None assert req.func_tool.get_tool("web_search_baidu") is builtin_tool + @pytest.mark.asyncio + async def test_apply_web_search_tools_adds_firecrawl_search_and_extract_tools( + self, mock_event, mock_context + ): + """Test Firecrawl web search injects search and extract tools.""" + module = ama + req = ProviderRequest() + mock_context.get_config.return_value = { + "provider_settings": { + "web_search": True, + "websearch_provider": "firecrawl", + } + } + search_tool = MagicMock(spec=FunctionTool) + search_tool.name = "web_search_firecrawl" + extract_tool = MagicMock(spec=FunctionTool) + extract_tool.name = "firecrawl_extract_web_page" + tool_mgr = MagicMock() + tool_mgr.get_builtin_tool.side_effect = [search_tool, extract_tool] + mock_context.get_llm_tool_manager.return_value = tool_mgr + + await module._apply_web_search_tools(mock_event, req, mock_context) + + assert tool_mgr.get_builtin_tool.call_args_list == [ + ((module.FirecrawlWebSearchTool,),), + ((module.FirecrawlExtractWebPageTool,),), + ] + assert req.func_tool is not None + assert req.func_tool.get_tool("web_search_firecrawl") is search_tool + assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool + def test_proactive_cron_job_tools_uses_builtin_tool_manager(self, mock_context): """Test cron tool injection through the builtin tool manager.""" module = ama diff --git a/tests/unit/test_func_tool_manager.py b/tests/unit/test_func_tool_manager.py index 908810cdb3..c87a2de085 100644 --- a/tests/unit/test_func_tool_manager.py +++ b/tests/unit/test_func_tool_manager.py @@ -2,6 +2,8 @@ from astrbot.core.provider.func_tool_manager import FunctionToolManager from astrbot.core.tools.computer_tools.shell import ExecuteShellTool from astrbot.core.tools.message_tools import SendMessageToUserTool +from astrbot.core.tools.web_search_tools import FirecrawlExtractWebPageTool +from astrbot.core.tools.web_search_tools import FirecrawlWebSearchTool def test_get_builtin_tool_by_class_returns_cached_instance(): @@ -38,3 +40,15 @@ def test_computer_tools_are_registered_as_builtin_tools(): assert tool.name == "astrbot_execute_shell" assert manager.is_builtin_tool("astrbot_execute_shell") is True + + +def test_firecrawl_tools_are_registered_as_builtin_tools(): + manager = FunctionToolManager() + + search_tool = manager.get_builtin_tool(FirecrawlWebSearchTool) + extract_tool = manager.get_builtin_tool(FirecrawlExtractWebPageTool) + + assert search_tool.name == "web_search_firecrawl" + assert extract_tool.name == "firecrawl_extract_web_page" + assert manager.is_builtin_tool("web_search_firecrawl") is True + assert manager.is_builtin_tool("firecrawl_extract_web_page") is True diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py new file mode 100644 index 0000000000..930b8638a8 --- /dev/null +++ b/tests/unit/test_web_search_tools.py @@ -0,0 +1,95 @@ +import json +from types import SimpleNamespace + +import pytest + +from astrbot.core.tools import web_search_tools as tools + + +class _FakeConfig(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.saved = False + + def save_config(self): + self.saved = True + + +def test_normalize_legacy_web_search_config_migrates_firecrawl_key(): + config = _FakeConfig( + {"provider_settings": {"websearch_firecrawl_key": "firecrawl-key"}} + ) + + tools.normalize_legacy_web_search_config(config) + + assert config["provider_settings"]["websearch_firecrawl_key"] == [ + "firecrawl-key" + ] + assert config.saved is True + + +@pytest.mark.asyncio +async def test_firecrawl_search_maps_web_results(monkeypatch): + async def fake_firecrawl_search(provider_settings, payload): + assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"] + assert payload == { + "query": "AstrBot", + "limit": 3, + "sources": ["web"], + "country": "US", + } + return [ + tools.SearchResult( + title="AstrBot", + url="https://example.com", + snippet="Search result", + ) + ] + + monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search) + tool = tools.FirecrawlWebSearchTool() + context = _context_with_provider_settings( + {"websearch_firecrawl_key": ["firecrawl-key"]} + ) + + result = await tool.call(context, query="AstrBot", limit=3, country="US") + + assert json.loads(result)["results"] == [ + { + "title": "AstrBot", + "url": "https://example.com", + "snippet": "Search result", + "index": json.loads(result)["results"][0]["index"], + } + ] + + +@pytest.mark.asyncio +async def test_firecrawl_extract_returns_scraped_markdown(monkeypatch): + async def fake_firecrawl_scrape(provider_settings, payload): + assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"] + assert payload == { + "url": "https://example.com", + "formats": ["markdown"], + "onlyMainContent": True, + } + return {"url": "https://example.com", "markdown": "# Example"} + + monkeypatch.setattr(tools, "_firecrawl_scrape", fake_firecrawl_scrape) + tool = tools.FirecrawlExtractWebPageTool() + context = _context_with_provider_settings( + {"websearch_firecrawl_key": ["firecrawl-key"]} + ) + + result = await tool.call(context, url="https://example.com") + + assert result == "URL: https://example.com\nContent: # Example" + + +def _context_with_provider_settings(provider_settings): + config = {"provider_settings": provider_settings} + agent_context = SimpleNamespace( + context=SimpleNamespace(get_config=lambda umo: config), + event=SimpleNamespace(unified_msg_origin="test:private:session"), + ) + return SimpleNamespace(context=agent_context) From dae5e15dd0a91361830b6df1a32980bf3c6074c5 Mon Sep 17 00:00:00 2001 From: wjiajian <71640797+wjiajian@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:44:25 +0800 Subject: [PATCH 2/7] feat: implement Firecrawl API integration and error handling in web search tools --- astrbot/core/tools/web_search_tools.py | 77 ++++++++++------------ tests/unit/test_web_search_tools.py | 88 +++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 45 deletions(-) diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index a1456f4c8f..339bca3cd7 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -76,6 +76,11 @@ async def get(self, provider_settings: dict) -> str: _BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha") _BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave") _FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl") +_FIRECRAWL_BASE_URL = "https://api.firecrawl.dev/v2" + + +class FirecrawlAPIError(RuntimeError): + pass def normalize_legacy_web_search_config(cfg) -> None: @@ -270,41 +275,35 @@ async def _firecrawl_search( provider_settings: dict, payload: dict, ) -> list[SearchResult]: - firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings) - headers = { - "Authorization": f"Bearer {firecrawl_key}", - "Content-Type": "application/json", - } - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.post( - "https://api.firecrawl.dev/v2/search", - json=payload, - headers=headers, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Firecrawl web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - rows = data.get("data", {}).get("web", []) - return [ - SearchResult( - title=item.get("title", ""), - url=item.get("url", ""), - snippet=( - item.get("description") - or item.get("snippet") - or item.get("markdown") - or "" - ), - ) - for item in rows - if item.get("url") - ] + data = await _firecrawl_post(provider_settings, "search", payload) + rows = data.get("data", {}).get("web", []) + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=( + item.get("description") + or item.get("snippet") + or item.get("markdown") + or "" + ), + ) + for item in rows + if item.get("url") + ] async def _firecrawl_scrape(provider_settings: dict, payload: dict) -> dict: + data = await _firecrawl_post(provider_settings, "scrape", payload) + result = data.get("data", {}) + if not result: + raise ValueError("Error: Firecrawl web scraper does not return any results.") + return result + + +async def _firecrawl_post( + provider_settings: dict, endpoint: str, payload: dict +) -> dict: firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings) headers = { "Authorization": f"Bearer {firecrawl_key}", @@ -312,22 +311,16 @@ async def _firecrawl_scrape(provider_settings: dict, payload: dict) -> dict: } async with aiohttp.ClientSession(trust_env=True) as session: async with session.post( - "https://api.firecrawl.dev/v2/scrape", + f"{_FIRECRAWL_BASE_URL}/{endpoint}", json=payload, headers=headers, ) as response: if response.status != 200: reason = await response.text() - raise Exception( - f"Firecrawl web scrape failed: {reason}, status: {response.status}", - ) - data = await response.json() - result = data.get("data", {}) - if not result: - raise ValueError( - "Error: Firecrawl web scraper does not return any results." + raise FirecrawlAPIError( + f"Firecrawl {endpoint} failed: {reason}, status: {response.status}", ) - return result + return await response.json() async def _baidu_search( diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index 930b8638a8..e6f3e25c2b 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -22,9 +22,7 @@ def test_normalize_legacy_web_search_config_migrates_firecrawl_key(): tools.normalize_legacy_web_search_config(config) - assert config["provider_settings"]["websearch_firecrawl_key"] == [ - "firecrawl-key" - ] + assert config["provider_settings"]["websearch_firecrawl_key"] == ["firecrawl-key"] assert config.saved is True @@ -86,6 +84,90 @@ async def fake_firecrawl_scrape(provider_settings, payload): assert result == "URL: https://example.com\nContent: # Example" +@pytest.mark.asyncio +async def test_firecrawl_post_uses_shared_request_setup(monkeypatch): + session = _FakeFirecrawlSession( + _FakeFirecrawlResponse(status=200, json_data={"success": True}) + ) + + def fake_client_session(trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + result = await tools._firecrawl_post( + {"websearch_firecrawl_key": ["firecrawl-key"]}, + "search", + {"query": "AstrBot"}, + ) + + assert result == {"success": True} + assert session.trust_env is True + assert session.posted == { + "url": "https://api.firecrawl.dev/v2/search", + "json": {"query": "AstrBot"}, + "headers": { + "Authorization": "Bearer firecrawl-key", + "Content-Type": "application/json", + }, + } + + +@pytest.mark.asyncio +async def test_firecrawl_post_raises_api_error_for_http_errors(monkeypatch): + session = _FakeFirecrawlSession( + _FakeFirecrawlResponse(status=401, text_data="Unauthorized") + ) + monkeypatch.setattr(tools.aiohttp, "ClientSession", lambda trust_env: session) + + with pytest.raises( + tools.FirecrawlAPIError, + match="Firecrawl scrape failed: Unauthorized, status: 401", + ): + await tools._firecrawl_post( + {"websearch_firecrawl_key": ["firecrawl-key"]}, + "scrape", + {"url": "https://example.com"}, + ) + + +class _FakeFirecrawlResponse: + def __init__(self, status=200, json_data=None, text_data=""): + self.status = status + self.json_data = json_data or {} + self.text_data = text_data + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def json(self): + return self.json_data + + async def text(self): + return self.text_data + + +class _FakeFirecrawlSession: + def __init__(self, response): + self.response = response + self.trust_env = None + self.posted = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def post(self, url, json, headers): + self.posted = {"url": url, "json": json, "headers": headers} + return self.response + + def _context_with_provider_settings(provider_settings): config = {"provider_settings": provider_settings} agent_context = SimpleNamespace( From 4b0deb37eb88e5ee76639b61fa6eda5e5dd76269 Mon Sep 17 00:00:00 2001 From: wjiajian <71640797+wjiajian@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:50:59 +0800 Subject: [PATCH 3/7] feat: enhance Firecrawl web search with session management and payload validation --- astrbot/core/tools/web_search_tools.py | 271 +++++++++++++------------ tests/unit/test_web_search_tools.py | 78 ++++++- 2 files changed, 216 insertions(+), 133 deletions(-) diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index 339bca3cd7..618e5c9016 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -77,12 +77,37 @@ async def get(self, provider_settings: dict) -> str: _BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave") _FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl") _FIRECRAWL_BASE_URL = "https://api.firecrawl.dev/v2" +_WEB_SEARCH_SESSION: aiohttp.ClientSession | None = None +_WEB_SEARCH_SESSION_LOOP: asyncio.AbstractEventLoop | None = None class FirecrawlAPIError(RuntimeError): pass +async def _get_web_search_session() -> aiohttp.ClientSession: + global _WEB_SEARCH_SESSION, _WEB_SEARCH_SESSION_LOOP + loop = asyncio.get_running_loop() + if ( + _WEB_SEARCH_SESSION is None + or _WEB_SEARCH_SESSION_LOOP is not loop + or _WEB_SEARCH_SESSION.closed + ): + _WEB_SEARCH_SESSION = aiohttp.ClientSession(trust_env=True) + _WEB_SEARCH_SESSION_LOOP = loop + return _WEB_SEARCH_SESSION + + +def _coerce_int(value, default: int, minimum: int, maximum: int) -> int: + if value is None: + value = default + try: + coerced = int(value) + except (TypeError, ValueError): + coerced = default + return max(minimum, min(maximum, coerced)) + + def normalize_legacy_web_search_config(cfg) -> None: provider_settings = cfg.get("provider_settings") if not provider_settings: @@ -153,27 +178,27 @@ async def _tavily_search( "Authorization": f"Bearer {tavily_key}", "Content-Type": "application/json", } - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.post( - "https://api.tavily.com/search", - json=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Tavily web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - return [ - SearchResult( - title=item.get("title"), - url=item.get("url"), - snippet=item.get("content"), - favicon=item.get("favicon"), - ) - for item in data.get("results", []) - ] + session = await _get_web_search_session() + async with session.post( + "https://api.tavily.com/search", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Tavily web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + return [ + SearchResult( + title=item.get("title"), + url=item.get("url"), + snippet=item.get("content"), + favicon=item.get("favicon"), + ) + for item in data.get("results", []) + ] async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]: @@ -182,24 +207,22 @@ async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]: "Authorization": f"Bearer {tavily_key}", "Content-Type": "application/json", } - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.post( - "https://api.tavily.com/extract", - json=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Tavily web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - results: list[dict] = data.get("results", []) - if not results: - raise ValueError( - "Error: Tavily web searcher does not return any results." - ) - return results + session = await _get_web_search_session() + async with session.post( + "https://api.tavily.com/extract", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Tavily web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + results: list[dict] = data.get("results", []) + if not results: + raise ValueError("Error: Tavily web searcher does not return any results.") + return results async def _bocha_search( @@ -215,28 +238,28 @@ async def _bocha_search( # See: https://github.com/aio-libs/aiohttp/issues/11898 "Accept-Encoding": "gzip, deflate", } - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.post( - "https://api.bochaai.com/v1/web-search", - json=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"BoCha web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - rows = data["data"]["webPages"]["value"] - return [ - SearchResult( - title=item.get("name"), - url=item.get("url"), - snippet=item.get("snippet"), - favicon=item.get("siteIcon"), - ) - for item in rows - ] + session = await _get_web_search_session() + async with session.post( + "https://api.bochaai.com/v1/web-search", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"BoCha web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + rows = data["data"]["webPages"]["value"] + return [ + SearchResult( + title=item.get("name"), + url=item.get("url"), + snippet=item.get("snippet"), + favicon=item.get("siteIcon"), + ) + for item in rows + ] async def _brave_search( @@ -248,27 +271,27 @@ async def _brave_search( "Accept": "application/json", "X-Subscription-Token": brave_key, } - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get( - "https://api.search.brave.com/res/v1/web/search", - params=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Brave web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - rows = data.get("web", {}).get("results", []) - return [ - SearchResult( - title=item.get("title", ""), - url=item.get("url", ""), - snippet=item.get("description", ""), - ) - for item in rows - ] + session = await _get_web_search_session() + async with session.get( + "https://api.search.brave.com/res/v1/web/search", + params=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Brave web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + rows = data.get("web", {}).get("results", []) + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=item.get("description", ""), + ) + for item in rows + ] async def _firecrawl_search( @@ -276,7 +299,7 @@ async def _firecrawl_search( payload: dict, ) -> list[SearchResult]: data = await _firecrawl_post(provider_settings, "search", payload) - rows = data.get("data", {}).get("web", []) + rows = data.get("data", []) return [ SearchResult( title=item.get("title", ""), @@ -309,18 +332,18 @@ async def _firecrawl_post( "Authorization": f"Bearer {firecrawl_key}", "Content-Type": "application/json", } - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.post( - f"{_FIRECRAWL_BASE_URL}/{endpoint}", - json=payload, - headers=headers, - ) as response: - if response.status != 200: - reason = await response.text() - raise FirecrawlAPIError( - f"Firecrawl {endpoint} failed: {reason}, status: {response.status}", - ) - return await response.json() + session = await _get_web_search_session() + async with session.post( + f"{_FIRECRAWL_BASE_URL}/{endpoint}", + json=payload, + headers=headers, + ) as response: + if response.status != 200: + reason = await response.text() + raise FirecrawlAPIError( + f"Firecrawl {endpoint} failed: {reason}, status: {response.status}", + ) + return await response.json() async def _baidu_search( @@ -336,29 +359,29 @@ async def _baidu_search( "X-Appbuilder-Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.post( - "https://qianfan.baidubce.com/v2/ai_search/web_search", - json=payload, - headers=headers, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Baidu AI Search failed: {reason}, status: {response.status}", - ) - data = await response.json() - references = data.get("references", []) - return [ - SearchResult( - title=item.get("title", ""), - url=item.get("url", ""), - snippet=item.get("content", ""), - favicon=item.get("icon"), - ) - for item in references - if item.get("url") - ] + session = await _get_web_search_session() + async with session.post( + "https://qianfan.baidubce.com/v2/ai_search/web_search", + json=payload, + headers=headers, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Baidu AI Search failed: {reason}, status: {response.status}", + ) + data = await response.json() + references = data.get("references", []) + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=item.get("content", ""), + favicon=item.get("icon"), + ) + for item in references + if item.get("url") + ] @builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG) @@ -630,10 +653,6 @@ class FirecrawlWebSearchTool(FunctionTool[AstrAgentContext]): "type": "integer", "description": "Optional. Number of results to return. Range: 1-100. Default is 5.", }, - "tbs": { - "type": "string", - "description": 'Optional. Time-based search filter, for example "qdr:d" or "qdr:w".', - }, "location": { "type": "string", "description": "Optional. Geographic location for search results.", @@ -656,18 +675,14 @@ async def call(self, context, **kwargs) -> ToolExecResult: if not provider_settings.get("websearch_firecrawl_key", []): return "Error: Firecrawl API key is not configured in AstrBot." - limit = int(kwargs.get("limit", 5)) - if limit < 1: - limit = 1 - if limit > 100: - limit = 100 + limit = _coerce_int(kwargs.get("limit"), 5, 1, 100) payload = { "query": kwargs["query"], "limit": limit, "sources": ["web"], } - for key in ("tbs", "location", "country", "timeout"): + for key in ("location", "country", "timeout"): if kwargs.get(key): payload[key] = kwargs[key] diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index e6f3e25c2b..3c97e5c4ad 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -62,6 +62,72 @@ async def fake_firecrawl_search(provider_settings, payload): ] +@pytest.mark.asyncio +async def test_firecrawl_search_maps_v2_data_list(monkeypatch): + async def fake_firecrawl_post(provider_settings, endpoint, payload): + assert endpoint == "search" + assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"] + assert payload == {"query": "AstrBot", "limit": 5, "sources": ["web"]} + return { + "success": True, + "data": [ + { + "title": "AstrBot", + "url": "https://example.com", + "description": "Search result", + } + ], + } + + monkeypatch.setattr(tools, "_firecrawl_post", fake_firecrawl_post) + + results = await tools._firecrawl_search( + {"websearch_firecrawl_key": ["firecrawl-key"]}, + {"query": "AstrBot", "limit": 5, "sources": ["web"]}, + ) + + assert results == [ + tools.SearchResult( + title="AstrBot", url="https://example.com", snippet="Search result" + ) + ] + + +@pytest.mark.asyncio +async def test_firecrawl_search_payload_omits_tbs_and_handles_null_limit(monkeypatch): + async def fake_firecrawl_search(provider_settings, payload): + assert payload == { + "query": "AstrBot", + "limit": 5, + "sources": ["web"], + "country": "US", + } + return [ + tools.SearchResult( + title="AstrBot", + url="https://example.com", + snippet="Search result", + ) + ] + + monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search) + tool = tools.FirecrawlWebSearchTool() + context = _context_with_provider_settings( + {"websearch_firecrawl_key": ["firecrawl-key"]} + ) + + result = await tool.call( + context, + query="AstrBot", + limit=None, + tbs="qdr:d", + country="US", + ) + + assert json.loads(result)["results"][0]["url"] == "https://example.com" + assert "tbs" not in tool.parameters["properties"] + + @pytest.mark.asyncio async def test_firecrawl_extract_returns_scraped_markdown(monkeypatch): async def fake_firecrawl_scrape(provider_settings, payload): @@ -90,11 +156,10 @@ async def test_firecrawl_post_uses_shared_request_setup(monkeypatch): _FakeFirecrawlResponse(status=200, json_data={"success": True}) ) - def fake_client_session(trust_env): - session.trust_env = trust_env + async def fake_get_web_search_session(): return session - monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + monkeypatch.setattr(tools, "_get_web_search_session", fake_get_web_search_session) result = await tools._firecrawl_post( {"websearch_firecrawl_key": ["firecrawl-key"]}, @@ -103,7 +168,6 @@ def fake_client_session(trust_env): ) assert result == {"success": True} - assert session.trust_env is True assert session.posted == { "url": "https://api.firecrawl.dev/v2/search", "json": {"query": "AstrBot"}, @@ -119,7 +183,11 @@ async def test_firecrawl_post_raises_api_error_for_http_errors(monkeypatch): session = _FakeFirecrawlSession( _FakeFirecrawlResponse(status=401, text_data="Unauthorized") ) - monkeypatch.setattr(tools.aiohttp, "ClientSession", lambda trust_env: session) + + async def fake_get_web_search_session(): + return session + + monkeypatch.setattr(tools, "_get_web_search_session", fake_get_web_search_session) with pytest.raises( tools.FirecrawlAPIError, From cc19cc2d30b8d67ca276892329354eb8e6d6401e Mon Sep 17 00:00:00 2001 From: wjiajian <71640797+wjiajian@users.noreply.github.com> Date: Sat, 25 Apr 2026 17:39:25 +0800 Subject: [PATCH 4/7] feat: Firecrawl web search to use aiohttp.ClientSession directly for improved session management as it was --- astrbot/core/tools/web_search_tools.py | 247 ++++++++++++------------- tests/unit/test_web_search_tools.py | 21 ++- 2 files changed, 134 insertions(+), 134 deletions(-) diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index 618e5c9016..6f83abaf35 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -77,27 +77,12 @@ async def get(self, provider_settings: dict) -> str: _BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave") _FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl") _FIRECRAWL_BASE_URL = "https://api.firecrawl.dev/v2" -_WEB_SEARCH_SESSION: aiohttp.ClientSession | None = None -_WEB_SEARCH_SESSION_LOOP: asyncio.AbstractEventLoop | None = None class FirecrawlAPIError(RuntimeError): pass -async def _get_web_search_session() -> aiohttp.ClientSession: - global _WEB_SEARCH_SESSION, _WEB_SEARCH_SESSION_LOOP - loop = asyncio.get_running_loop() - if ( - _WEB_SEARCH_SESSION is None - or _WEB_SEARCH_SESSION_LOOP is not loop - or _WEB_SEARCH_SESSION.closed - ): - _WEB_SEARCH_SESSION = aiohttp.ClientSession(trust_env=True) - _WEB_SEARCH_SESSION_LOOP = loop - return _WEB_SEARCH_SESSION - - def _coerce_int(value, default: int, minimum: int, maximum: int) -> int: if value is None: value = default @@ -178,27 +163,27 @@ async def _tavily_search( "Authorization": f"Bearer {tavily_key}", "Content-Type": "application/json", } - session = await _get_web_search_session() - async with session.post( - "https://api.tavily.com/search", - json=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Tavily web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - return [ - SearchResult( - title=item.get("title"), - url=item.get("url"), - snippet=item.get("content"), - favicon=item.get("favicon"), - ) - for item in data.get("results", []) - ] + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://api.tavily.com/search", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Tavily web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + return [ + SearchResult( + title=item.get("title"), + url=item.get("url"), + snippet=item.get("content"), + favicon=item.get("favicon"), + ) + for item in data.get("results", []) + ] async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]: @@ -207,22 +192,24 @@ async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]: "Authorization": f"Bearer {tavily_key}", "Content-Type": "application/json", } - session = await _get_web_search_session() - async with session.post( - "https://api.tavily.com/extract", - json=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Tavily web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - results: list[dict] = data.get("results", []) - if not results: - raise ValueError("Error: Tavily web searcher does not return any results.") - return results + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://api.tavily.com/extract", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Tavily web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + results: list[dict] = data.get("results", []) + if not results: + raise ValueError( + "Error: Tavily web searcher does not return any results." + ) + return results async def _bocha_search( @@ -238,28 +225,28 @@ async def _bocha_search( # See: https://github.com/aio-libs/aiohttp/issues/11898 "Accept-Encoding": "gzip, deflate", } - session = await _get_web_search_session() - async with session.post( - "https://api.bochaai.com/v1/web-search", - json=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"BoCha web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - rows = data["data"]["webPages"]["value"] - return [ - SearchResult( - title=item.get("name"), - url=item.get("url"), - snippet=item.get("snippet"), - favicon=item.get("siteIcon"), - ) - for item in rows - ] + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://api.bochaai.com/v1/web-search", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"BoCha web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + rows = data["data"]["webPages"]["value"] + return [ + SearchResult( + title=item.get("name"), + url=item.get("url"), + snippet=item.get("snippet"), + favicon=item.get("siteIcon"), + ) + for item in rows + ] async def _brave_search( @@ -271,27 +258,27 @@ async def _brave_search( "Accept": "application/json", "X-Subscription-Token": brave_key, } - session = await _get_web_search_session() - async with session.get( - "https://api.search.brave.com/res/v1/web/search", - params=payload, - headers=header, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Brave web search failed: {reason}, status: {response.status}", - ) - data = await response.json() - rows = data.get("web", {}).get("results", []) - return [ - SearchResult( - title=item.get("title", ""), - url=item.get("url", ""), - snippet=item.get("description", ""), - ) - for item in rows - ] + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + "https://api.search.brave.com/res/v1/web/search", + params=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Brave web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + rows = data.get("web", {}).get("results", []) + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=item.get("description", ""), + ) + for item in rows + ] async def _firecrawl_search( @@ -332,18 +319,18 @@ async def _firecrawl_post( "Authorization": f"Bearer {firecrawl_key}", "Content-Type": "application/json", } - session = await _get_web_search_session() - async with session.post( - f"{_FIRECRAWL_BASE_URL}/{endpoint}", - json=payload, - headers=headers, - ) as response: - if response.status != 200: - reason = await response.text() - raise FirecrawlAPIError( - f"Firecrawl {endpoint} failed: {reason}, status: {response.status}", - ) - return await response.json() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + f"{_FIRECRAWL_BASE_URL}/{endpoint}", + json=payload, + headers=headers, + ) as response: + if response.status != 200: + reason = await response.text() + raise FirecrawlAPIError( + f"Firecrawl {endpoint} failed: {reason}, status: {response.status}", + ) + return await response.json() async def _baidu_search( @@ -359,29 +346,29 @@ async def _baidu_search( "X-Appbuilder-Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } - session = await _get_web_search_session() - async with session.post( - "https://qianfan.baidubce.com/v2/ai_search/web_search", - json=payload, - headers=headers, - ) as response: - if response.status != 200: - reason = await response.text() - raise Exception( - f"Baidu AI Search failed: {reason}, status: {response.status}", - ) - data = await response.json() - references = data.get("references", []) - return [ - SearchResult( - title=item.get("title", ""), - url=item.get("url", ""), - snippet=item.get("content", ""), - favicon=item.get("icon"), - ) - for item in references - if item.get("url") - ] + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://qianfan.baidubce.com/v2/ai_search/web_search", + json=payload, + headers=headers, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Baidu AI Search failed: {reason}, status: {response.status}", + ) + data = await response.json() + references = data.get("references", []) + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=item.get("content", ""), + favicon=item.get("icon"), + ) + for item in references + if item.get("url") + ] @builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG) diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index 3c97e5c4ad..a20676ed54 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -156,10 +156,11 @@ async def test_firecrawl_post_uses_shared_request_setup(monkeypatch): _FakeFirecrawlResponse(status=200, json_data={"success": True}) ) - async def fake_get_web_search_session(): + def fake_client_session(*, trust_env): + session.trust_env = trust_env return session - monkeypatch.setattr(tools, "_get_web_search_session", fake_get_web_search_session) + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) result = await tools._firecrawl_post( {"websearch_firecrawl_key": ["firecrawl-key"]}, @@ -168,6 +169,9 @@ async def fake_get_web_search_session(): ) assert result == {"success": True} + assert session.trust_env is True + assert session.entered is True + assert session.exited is True assert session.posted == { "url": "https://api.firecrawl.dev/v2/search", "json": {"query": "AstrBot"}, @@ -184,10 +188,11 @@ async def test_firecrawl_post_raises_api_error_for_http_errors(monkeypatch): _FakeFirecrawlResponse(status=401, text_data="Unauthorized") ) - async def fake_get_web_search_session(): + def fake_client_session(*, trust_env): + session.trust_env = trust_env return session - monkeypatch.setattr(tools, "_get_web_search_session", fake_get_web_search_session) + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) with pytest.raises( tools.FirecrawlAPIError, @@ -199,6 +204,10 @@ async def fake_get_web_search_session(): {"url": "https://example.com"}, ) + assert session.trust_env is True + assert session.entered is True + assert session.exited is True + class _FakeFirecrawlResponse: def __init__(self, status=200, json_data=None, text_data=""): @@ -223,12 +232,16 @@ class _FakeFirecrawlSession: def __init__(self, response): self.response = response self.trust_env = None + self.entered = False + self.exited = False self.posted = None async def __aenter__(self): + self.entered = True return self async def __aexit__(self, exc_type, exc, tb): + self.exited = True return None def post(self, url, json, headers): From b31bb0effcc97b7abda00487323298f12c348c14 Mon Sep 17 00:00:00 2001 From: wjiajian <71640797+wjiajian@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:18:50 +0800 Subject: [PATCH 5/7] feat: update Firecrawl search to handle grouped web data response and add corresponding tests --- astrbot/core/tools/web_search_tools.py | 2 ++ tests/unit/test_web_search_tools.py | 33 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index 6f83abaf35..1e3b4404b8 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -287,6 +287,8 @@ async def _firecrawl_search( ) -> list[SearchResult]: data = await _firecrawl_post(provider_settings, "search", payload) rows = data.get("data", []) + if isinstance(rows, dict): + rows = rows.get("web", []) return [ SearchResult( title=item.get("title", ""), diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index a20676ed54..45ad5c9069 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -93,6 +93,39 @@ async def fake_firecrawl_post(provider_settings, endpoint, payload): ] +@pytest.mark.asyncio +async def test_firecrawl_search_maps_v2_grouped_web_data(monkeypatch): + async def fake_firecrawl_post(provider_settings, endpoint, payload): + assert endpoint == "search" + assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"] + assert payload == {"query": "AstrBot", "limit": 5, "sources": ["web"]} + return { + "success": True, + "data": { + "web": [ + { + "title": "AstrBot", + "url": "https://example.com", + "description": "Search result", + } + ] + }, + } + + monkeypatch.setattr(tools, "_firecrawl_post", fake_firecrawl_post) + + results = await tools._firecrawl_search( + {"websearch_firecrawl_key": ["firecrawl-key"]}, + {"query": "AstrBot", "limit": 5, "sources": ["web"]}, + ) + + assert results == [ + tools.SearchResult( + title="AstrBot", url="https://example.com", snippet="Search result" + ) + ] + + @pytest.mark.asyncio async def test_firecrawl_search_payload_omits_tbs_and_handles_null_limit(monkeypatch): async def fake_firecrawl_search(provider_settings, payload): From 866caf2f2ee6e04737e51456f145436663143dd9 Mon Sep 17 00:00:00 2001 From: wjiajian <71640797+wjiajian@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:24:06 +0800 Subject: [PATCH 6/7] feat: refactor Firecrawl web search to use aiohttp.ClientSession for improved error handling and session management --- astrbot/core/tools/web_search_tools.py | 85 ++++++------ tests/unit/test_web_search_tools.py | 174 +++++++++++++++++++------ 2 files changed, 178 insertions(+), 81 deletions(-) diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index 1e3b4404b8..22577a69e7 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -76,11 +76,6 @@ async def get(self, provider_settings: dict) -> str: _BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha") _BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave") _FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl") -_FIRECRAWL_BASE_URL = "https://api.firecrawl.dev/v2" - - -class FirecrawlAPIError(RuntimeError): - pass def _coerce_int(value, default: int, minimum: int, maximum: int) -> int: @@ -285,54 +280,66 @@ async def _firecrawl_search( provider_settings: dict, payload: dict, ) -> list[SearchResult]: - data = await _firecrawl_post(provider_settings, "search", payload) - rows = data.get("data", []) - if isinstance(rows, dict): - rows = rows.get("web", []) - return [ - SearchResult( - title=item.get("title", ""), - url=item.get("url", ""), - snippet=( - item.get("description") - or item.get("snippet") - or item.get("markdown") - or "" - ), - ) - for item in rows - if item.get("url") - ] + firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings) + header = { + "Authorization": f"Bearer {firecrawl_key}", + "Content-Type": "application/json", + } + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://api.firecrawl.dev/v2/search", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Firecrawl web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + rows = data.get("data", []) + if isinstance(rows, dict): + rows = rows.get("web", []) + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=( + item.get("description") + or item.get("snippet") + or item.get("markdown") + or "" + ), + ) + for item in rows + if item.get("url") + ] async def _firecrawl_scrape(provider_settings: dict, payload: dict) -> dict: - data = await _firecrawl_post(provider_settings, "scrape", payload) - result = data.get("data", {}) - if not result: - raise ValueError("Error: Firecrawl web scraper does not return any results.") - return result - - -async def _firecrawl_post( - provider_settings: dict, endpoint: str, payload: dict -) -> dict: firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings) - headers = { + header = { "Authorization": f"Bearer {firecrawl_key}", "Content-Type": "application/json", } async with aiohttp.ClientSession(trust_env=True) as session: async with session.post( - f"{_FIRECRAWL_BASE_URL}/{endpoint}", + "https://api.firecrawl.dev/v2/scrape", json=payload, - headers=headers, + headers=header, ) as response: if response.status != 200: reason = await response.text() - raise FirecrawlAPIError( - f"Firecrawl {endpoint} failed: {reason}, status: {response.status}", + raise Exception( + f"Firecrawl web scraper failed: {reason}, status: {response.status}", + ) + data = await response.json() + result = data.get("data", {}) + if not result: + raise ValueError( + "Error: Firecrawl web scraper does not return any results." ) - return await response.json() + return result async def _baidu_search( diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index 45ad5c9069..cf61838856 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -64,28 +64,41 @@ async def fake_firecrawl_search(provider_settings, payload): @pytest.mark.asyncio async def test_firecrawl_search_maps_v2_data_list(monkeypatch): - async def fake_firecrawl_post(provider_settings, endpoint, payload): - assert endpoint == "search" - assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"] - assert payload == {"query": "AstrBot", "limit": 5, "sources": ["web"]} - return { - "success": True, - "data": [ - { - "title": "AstrBot", - "url": "https://example.com", - "description": "Search result", - } - ], - } + session = _FakeFirecrawlSession( + _FakeFirecrawlResponse( + status=200, + json_data={ + "success": True, + "data": [ + { + "title": "AstrBot", + "url": "https://example.com", + "description": "Search result", + } + ], + }, + ) + ) - monkeypatch.setattr(tools, "_firecrawl_post", fake_firecrawl_post) + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) results = await tools._firecrawl_search( {"websearch_firecrawl_key": ["firecrawl-key"]}, {"query": "AstrBot", "limit": 5, "sources": ["web"]}, ) + assert session.posted == { + "url": "https://api.firecrawl.dev/v2/search", + "json": {"query": "AstrBot", "limit": 5, "sources": ["web"]}, + "headers": { + "Authorization": "Bearer firecrawl-key", + "Content-Type": "application/json", + }, + } assert results == [ tools.SearchResult( title="AstrBot", url="https://example.com", snippet="Search result" @@ -95,24 +108,29 @@ async def fake_firecrawl_post(provider_settings, endpoint, payload): @pytest.mark.asyncio async def test_firecrawl_search_maps_v2_grouped_web_data(monkeypatch): - async def fake_firecrawl_post(provider_settings, endpoint, payload): - assert endpoint == "search" - assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"] - assert payload == {"query": "AstrBot", "limit": 5, "sources": ["web"]} - return { - "success": True, - "data": { - "web": [ - { - "title": "AstrBot", - "url": "https://example.com", - "description": "Search result", - } - ] + session = _FakeFirecrawlSession( + _FakeFirecrawlResponse( + status=200, + json_data={ + "success": True, + "data": { + "web": [ + { + "title": "AstrBot", + "url": "https://example.com", + "description": "Search result", + } + ] + }, }, - } + ) + ) - monkeypatch.setattr(tools, "_firecrawl_post", fake_firecrawl_post) + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) results = await tools._firecrawl_search( {"websearch_firecrawl_key": ["firecrawl-key"]}, @@ -184,9 +202,21 @@ async def fake_firecrawl_scrape(provider_settings, payload): @pytest.mark.asyncio -async def test_firecrawl_post_uses_shared_request_setup(monkeypatch): +async def test_firecrawl_search_uses_session_context(monkeypatch): session = _FakeFirecrawlSession( - _FakeFirecrawlResponse(status=200, json_data={"success": True}) + _FakeFirecrawlResponse( + status=200, + json_data={ + "success": True, + "data": [ + { + "title": "AstrBot", + "url": "https://example.com", + "description": "Search result", + } + ], + }, + ) ) def fake_client_session(*, trust_env): @@ -195,13 +225,11 @@ def fake_client_session(*, trust_env): monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) - result = await tools._firecrawl_post( + await tools._firecrawl_search( {"websearch_firecrawl_key": ["firecrawl-key"]}, - "search", {"query": "AstrBot"}, ) - assert result == {"success": True} assert session.trust_env is True assert session.entered is True assert session.exited is True @@ -216,7 +244,70 @@ def fake_client_session(*, trust_env): @pytest.mark.asyncio -async def test_firecrawl_post_raises_api_error_for_http_errors(monkeypatch): +async def test_firecrawl_search_raises_error_for_http_errors(monkeypatch): + session = _FakeFirecrawlSession( + _FakeFirecrawlResponse(status=401, text_data="Unauthorized") + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + with pytest.raises( + Exception, + match="Firecrawl web search failed: Unauthorized, status: 401", + ): + await tools._firecrawl_search( + {"websearch_firecrawl_key": ["firecrawl-key"]}, + {"query": "AstrBot"}, + ) + + assert session.trust_env is True + assert session.entered is True + assert session.exited is True + + +@pytest.mark.asyncio +async def test_firecrawl_scrape_uses_request_setup(monkeypatch): + session = _FakeFirecrawlSession( + _FakeFirecrawlResponse( + status=200, + json_data={ + "success": True, + "data": {"url": "https://example.com", "markdown": "# Example"}, + }, + ) + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + result = await tools._firecrawl_scrape( + {"websearch_firecrawl_key": ["firecrawl-key"]}, + {"url": "https://example.com", "formats": ["markdown"]}, + ) + + assert result == {"url": "https://example.com", "markdown": "# Example"} + assert session.trust_env is True + assert session.entered is True + assert session.exited is True + assert session.posted == { + "url": "https://api.firecrawl.dev/v2/scrape", + "json": {"url": "https://example.com", "formats": ["markdown"]}, + "headers": { + "Authorization": "Bearer firecrawl-key", + "Content-Type": "application/json", + }, + } + + +@pytest.mark.asyncio +async def test_firecrawl_scrape_raises_error_for_http_errors(monkeypatch): session = _FakeFirecrawlSession( _FakeFirecrawlResponse(status=401, text_data="Unauthorized") ) @@ -228,13 +319,12 @@ def fake_client_session(*, trust_env): monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) with pytest.raises( - tools.FirecrawlAPIError, - match="Firecrawl scrape failed: Unauthorized, status: 401", + Exception, + match="Firecrawl web scraper failed: Unauthorized, status: 401", ): - await tools._firecrawl_post( + await tools._firecrawl_scrape( {"websearch_firecrawl_key": ["firecrawl-key"]}, - "scrape", - {"url": "https://example.com"}, + {"url": "https://example.com", "formats": ["markdown"]}, ) assert session.trust_env is True From ba478c51ba4f3fda127031629bb235758130d42e Mon Sep 17 00:00:00 2001 From: wjiajian <71640797+wjiajian@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:38:30 +0800 Subject: [PATCH 7/7] feat: remove unused coercion function and update Firecrawl search to use default limit in payload --- astrbot/core/tools/web_search_tools.py | 14 +------------- tests/unit/test_web_search_tools.py | 3 +-- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index 22577a69e7..ebd13d0102 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -78,16 +78,6 @@ async def get(self, provider_settings: dict) -> str: _FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl") -def _coerce_int(value, default: int, minimum: int, maximum: int) -> int: - if value is None: - value = default - try: - coerced = int(value) - except (TypeError, ValueError): - coerced = default - return max(minimum, min(maximum, coerced)) - - def normalize_legacy_web_search_config(cfg) -> None: provider_settings = cfg.get("provider_settings") if not provider_settings: @@ -671,11 +661,9 @@ async def call(self, context, **kwargs) -> ToolExecResult: if not provider_settings.get("websearch_firecrawl_key", []): return "Error: Firecrawl API key is not configured in AstrBot." - limit = _coerce_int(kwargs.get("limit"), 5, 1, 100) - payload = { "query": kwargs["query"], - "limit": limit, + "limit": kwargs.get("limit", 5), "sources": ["web"], } for key in ("location", "country", "timeout"): diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index cf61838856..c0ac3cf800 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -145,7 +145,7 @@ def fake_client_session(*, trust_env): @pytest.mark.asyncio -async def test_firecrawl_search_payload_omits_tbs_and_handles_null_limit(monkeypatch): +async def test_firecrawl_search_payload_omits_tbs_and_uses_default_limit(monkeypatch): async def fake_firecrawl_search(provider_settings, payload): assert payload == { "query": "AstrBot", @@ -170,7 +170,6 @@ async def fake_firecrawl_search(provider_settings, payload): result = await tool.call( context, query="AstrBot", - limit=None, tbs="qdr:d", country="US", )