diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py
index a9ddb2e7b9..87cb2db064 100644
--- a/astrbot/core/astr_main_agent.py
+++ b/astrbot/core/astr_main_agent.py
@@ -77,6 +77,8 @@
     BaiduWebSearchTool,
     BochaWebSearchTool,
     BraveWebSearchTool,
+    FirecrawlExtractWebPageTool,
+    FirecrawlWebSearchTool,
     TavilyExtractWebPageTool,
     TavilyWebSearchTool,
     normalize_legacy_web_search_config,
@@ -1047,6 +1049,9 @@ async def _apply_web_search_tools(
         req.func_tool.add_tool(tool_mgr.get_builtin_tool(BochaWebSearchTool))
     elif provider == "brave":
         req.func_tool.add_tool(tool_mgr.get_builtin_tool(BraveWebSearchTool))
+    elif provider == "firecrawl":
+        req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlWebSearchTool))
+        req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlExtractWebPageTool))
     elif provider == "baidu_ai_search":
         req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool))
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 903f6c445f..cd1c81a888 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -3202,6 +3202,7 @@ class ChatProviderTemplate(TypedDict):
             "baidu_ai_search",
             "bocha",
             "brave",
+            "firecrawl",
         ],
         "condition": {
             "provider_settings.web_search": True,
@@ -3237,6 +3238,16 @@ class ChatProviderTemplate(TypedDict):
             "provider_settings.web_search": True,
         },
     },
+    "provider_settings.websearch_firecrawl_key": {
+        "description": "Firecrawl API Key",
+        "type": "list",
+        "items": {"type": "string"},
+        "hint": "可添加多个 Key 进行轮询。",
+        "condition": {
+            "provider_settings.websearch_provider": "firecrawl",
+            "provider_settings.web_search": True,
+        },
+    },
     "provider_settings.websearch_baidu_app_builder_key": {
         "description": "百度千帆智能云 APP Builder API Key",
         "type": "string",
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index f2d9474906..eb01ff01dc 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -519,6 +519,42 @@ async def get_models(self):
         except NotFoundError as e:
             raise Exception(f"获取模型列表失败:{e}")
 
+    @staticmethod
+    def _sanitize_assistant_messages(payloads: dict) -> None:
+        """在请求发送前过滤/规范化空的 assistant 消息。
+
+        严格 API(Moonshot、DeepSeek Reasoner 等)会在 assistant 消息同时缺少
+        ``content`` 和 ``tool_calls`` 时返回 400。把 ``""`` / ``None`` / ``[]``
+        都视作空内容:无 tool_calls 时整条过滤掉;有 tool_calls 时将 content
+        设为 ``None`` 以符合 OpenAI 规范。就地修改 ``payloads["messages"]``。
+        """
+        messages = payloads.get("messages")
+        if not isinstance(messages, list):
+            return
+
+        def _is_empty(content: Any) -> bool:
+            return content is None or content == "" or content == []
+
+        cleaned: list[Any] = []
+        for idx, msg in enumerate(messages):
+            if not isinstance(msg, dict) or msg.get("role") != "assistant":
+                cleaned.append(msg)
+                continue
+
+            content = msg.get("content")
+            tool_calls = msg.get("tool_calls")
+
+            if _is_empty(content) and not tool_calls:
+                logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
+                continue
+
+            if _is_empty(content) and tool_calls:
+                msg["content"] = None
+
+            cleaned.append(msg)
+
+        payloads["messages"] = cleaned
+
     async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
         if tools:
             model = payloads.get("model", "").lower()
@@ -548,26 +584,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
 
         model = payloads.get("model", "").lower()
 
-        if "messages" in payloads and isinstance(payloads["messages"], list):
-            cleaned_messages = []
-            for idx, msg in enumerate(payloads["messages"]):
-                # 过滤空的 assistant 消息,防止严格 API(如 Moonshot)返回 400 错误
-                if msg.get("role") == "assistant":
-                    content = msg.get("content")
-                    tool_calls = msg.get("tool_calls")
-
-                    # 情况1: 空/null content 且无 tool_calls -> 过滤掉
-                    if not tool_calls and (content == "" or content is None):
-                        logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
-                        continue
-
-                    # 情况2: 空 content 但有 tool_calls -> 设为 None (符合 OpenAI 规范)
-                    if content == "" and tool_calls:
-                        msg["content"] = None
-
-                cleaned_messages.append(msg)
-
-            payloads["messages"] = cleaned_messages
+        self._sanitize_assistant_messages(payloads)
 
         completion = await self.client.chat.completions.create(
             **payloads,
@@ -619,6 +636,8 @@ async def _query_stream(
                 del payloads[key]
         self._apply_provider_specific_extra_body_overrides(extra_body)
 
+        self._sanitize_assistant_messages(payloads)
+
         stream = await self.client.chat.completions.create(
             **payloads,
             stream=True,
diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py
index ca89bc17d4..ebd13d0102 100644
--- a/astrbot/core/tools/web_search_tools.py
+++ b/astrbot/core/tools/web_search_tools.py
@@ -19,6 +19,8 @@
     "tavily_extract_web_page",
     "web_search_bocha",
     "web_search_brave",
+    "web_search_firecrawl",
+    "firecrawl_extract_web_page",
 ]
 _TAVILY_WEB_SEARCH_TOOL_CONFIG = {
     "provider_settings.web_search": True,
@@ -32,6 +34,10 @@
     "provider_settings.web_search": True,
     "provider_settings.websearch_provider": "brave",
 }
+_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG = {
+    "provider_settings.web_search": True,
+    "provider_settings.websearch_provider": "firecrawl",
+}
 _BAIDU_WEB_SEARCH_TOOL_CONFIG = {
     "provider_settings.web_search": True,
     "provider_settings.websearch_provider": "baidu_ai_search",
@@ -69,6 +75,7 @@ async def get(self, provider_settings: dict) -> str:
 _TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily")
 _BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha")
 _BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave")
+_FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl")
 
 
 def normalize_legacy_web_search_config(cfg) -> None:
@@ -91,6 +98,7 @@ def normalize_legacy_web_search_config(cfg) -> None:
         "websearch_tavily_key",
         "websearch_bocha_key",
         "websearch_brave_key",
+        "websearch_firecrawl_key",
     ):
         value = provider_settings.get(setting_name)
         if isinstance(value, str):
@@ -258,6 +266,72 @@ async def _brave_search(
     provider_settings: dict,
     payload: dict,
 ) -> list[SearchResult]:
     ]
+
+
+async def _firecrawl_search(
+    provider_settings: dict,
+    payload: dict,
+) -> list[SearchResult]:
+    firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings)
+    header = {
+        "Authorization": f"Bearer {firecrawl_key}",
+        "Content-Type": "application/json",
+    }
+    async with aiohttp.ClientSession(trust_env=True) as session:
+        async with session.post(
+            "https://api.firecrawl.dev/v2/search",
+            json=payload,
+            headers=header,
+        ) as response:
+            if response.status != 200:
+                reason = await response.text()
+                raise Exception(
+                    f"Firecrawl web search failed: {reason}, status: {response.status}",
+                )
+            data = await response.json()
+    rows = data.get("data", [])
+    if isinstance(rows, dict):
+        rows = rows.get("web", [])
+    return [
+        SearchResult(
+            title=item.get("title", ""),
+            url=item.get("url", ""),
+            snippet=(
+                item.get("description")
+                or item.get("snippet")
+                or item.get("markdown")
+                or ""
+            ),
+        )
+        for item in rows
+        if item.get("url")
+    ]
+
+
+async def _firecrawl_scrape(provider_settings: dict, payload: dict) -> dict:
+    firecrawl_key = await _FIRECRAWL_KEY_ROTATOR.get(provider_settings)
+    header = {
+        "Authorization": f"Bearer {firecrawl_key}",
+        "Content-Type": "application/json",
+    }
+    async with aiohttp.ClientSession(trust_env=True) as session:
+        async with session.post(
+            "https://api.firecrawl.dev/v2/scrape",
+            json=payload,
+            headers=header,
+        ) as response:
+            if response.status != 200:
+                reason = await response.text()
+                raise Exception(
+                    f"Firecrawl web scraper failed: {reason}, status: {response.status}",
+                )
+            data = await response.json()
+    result = data.get("data", {})
+    if not result:
+        raise ValueError(
+            "Error: Firecrawl web scraper does not return any results."
+        )
+    return result
+
+
 async def _baidu_search(
     provider_settings: dict,
     payload: dict,
 ) -> list[SearchResult]:
@@ -548,6 +622,124 @@ async def call(self, context, **kwargs) -> ToolExecResult:
         return _search_result_payload(results)
 
 
+@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG)
+@pydantic_dataclass
+class FirecrawlWebSearchTool(FunctionTool[AstrAgentContext]):
+    name: str = "web_search_firecrawl"
+    description: str = (
+        "A web search tool based on Firecrawl Search API, used to retrieve web "
+        "pages related to the user's query."
+    )
+    parameters: dict = Field(
+        default_factory=lambda: {
+            "type": "object",
+            "properties": {
+                "query": {"type": "string", "description": "Required. Search query."},
+                "limit": {
+                    "type": "integer",
+                    "description": "Optional. Number of results to return. Range: 1-100. Default is 5.",
+                },
+                "location": {
+                    "type": "string",
+                    "description": "Optional. Geographic location for search results.",
+                },
+                "country": {
+                    "type": "string",
+                    "description": 'Optional. Country code for search results, for example "US" or "CN".',
+                },
+                "timeout": {
+                    "type": "integer",
+                    "description": "Optional. Request timeout in milliseconds.",
+                },
+            },
+            "required": ["query"],
+        }
+    )
+
+    async def call(self, context, **kwargs) -> ToolExecResult:
+        _, provider_settings, _ = _get_runtime(context)
+        if not provider_settings.get("websearch_firecrawl_key", []):
+            return "Error: Firecrawl API key is not configured in AstrBot."
+
+        payload = {
+            "query": kwargs["query"],
+            "limit": kwargs.get("limit", 5),
+            "sources": ["web"],
+        }
+        for key in ("location", "country", "timeout"):
+            if kwargs.get(key):
+                payload[key] = kwargs[key]
+
+        results = await _firecrawl_search(provider_settings, payload)
+        if not results:
+            return "Error: Firecrawl web searcher does not return any results."
+        return _search_result_payload(results)
+
+
+@builtin_tool(config=_FIRECRAWL_WEB_SEARCH_TOOL_CONFIG)
+@pydantic_dataclass
+class FirecrawlExtractWebPageTool(FunctionTool[AstrAgentContext]):
+    name: str = "firecrawl_extract_web_page"
+    description: str = "Extract the content of a web page using Firecrawl."
+    parameters: dict = Field(
+        default_factory=lambda: {
+            "type": "object",
+            "properties": {
+                "url": {
+                    "type": "string",
+                    "description": "Required. A URL to extract content from.",
+                },
+                "format": {
+                    "type": "string",
+                    "description": 'Optional. Output format, one of "markdown", "html", "rawHtml", "summary". Default is "markdown".',
+                },
+                "only_main_content": {
+                    "type": "boolean",
+                    "description": "Optional. Whether to extract only the main page content. Default is true.",
+                },
+                "timeout": {
+                    "type": "integer",
+                    "description": "Optional. Request timeout in milliseconds.",
+                },
+                "max_age": {
+                    "type": "integer",
+                    "description": "Optional. Maximum cache age in milliseconds.",
+                },
+            },
+            "required": ["url"],
+        }
+    )
+
+    async def call(self, context, **kwargs) -> ToolExecResult:
+        _, provider_settings, _ = _get_runtime(context)
+        if not provider_settings.get("websearch_firecrawl_key", []):
+            return "Error: Firecrawl API key is not configured in AstrBot."
+
+        url = str(kwargs.get("url", "")).strip()
+        if not url:
+            return "Error: url must be a non-empty string."
+
+        output_format = kwargs.get("format", "markdown")
+        if output_format not in ["markdown", "html", "rawHtml", "summary"]:
+            output_format = "markdown"
+
+        payload = {
+            "url": url,
+            "formats": [output_format],
+            "onlyMainContent": kwargs.get("only_main_content", True),
+        }
+        if kwargs.get("timeout"):
+            payload["timeout"] = kwargs["timeout"]
+        if kwargs.get("max_age"):
+            payload["maxAge"] = kwargs["max_age"]
+
+        result = await _firecrawl_scrape(provider_settings, payload)
+        content = result.get(output_format, "")
+        result_url = result.get("url") or url
+        ret = f"URL: {result_url}\nContent: {content}" if content else ""
+        return ret or "Error: Firecrawl web scraper does not return any results."
+
+
 @builtin_tool(config=_BAIDU_WEB_SEARCH_TOOL_CONFIG)
 @pydantic_dataclass
 class BaiduWebSearchTool(FunctionTool[AstrAgentContext]):
diff --git a/dashboard/src/components/chat/MessageListDEPRECATED.vue b/dashboard/src/components/chat/MessageListDEPRECATED.vue
index 128d0d97b1..1271042b39 100644
--- a/dashboard/src/components/chat/MessageListDEPRECATED.vue
+++ b/dashboard/src/components/chat/MessageListDEPRECATED.vue
@@ -303,7 +303,7 @@ export default {
       part.tool_calls.forEach(toolCall => {
         // 检查是否是支持引用解析的 web_search 工具调用
         if (
-          !['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave'].includes(toolCall.name) ||
+          !['web_search_baidu', 'web_search_tavily', 'web_search_bocha', 'web_search_brave', 'web_search_firecrawl'].includes(toolCall.name) ||
           !toolCall.result
         ) {
           return;
diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json
index 373182fc15..4f35dd2859 100644
--- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json
@@ -125,6 +125,10 @@
     "description": "Brave Search API Key",
     "hint": "Multiple keys can be added for rotation."
   },
+  "websearch_firecrawl_key": {
+    "description": "Firecrawl API Key",
+    "hint": "Multiple keys can be added for rotation."
+  },
   "websearch_baidu_app_builder_key": {
     "description": "Baidu Qianfan Smart Cloud APP Builder API Key",
     "hint": "Reference: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json
index c578f79d1c..08d11aed6a 100644
--- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json
@@ -125,6 +125,10 @@
     "description": "API-ключ Brave Search",
     "hint": "Можно добавить несколько ключей для ротации."
   },
+  "websearch_firecrawl_key": {
+    "description": "API-ключ Firecrawl",
+    "hint": "Можно добавить несколько ключей для ротации."
+  },
   "websearch_baidu_app_builder_key": {
     "description": "API-ключ Baidu Qianfan APP Builder",
     "hint": "Ссылка: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
index dd8711345c..8495f9ba1a 100644
--- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
+++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
@@ -127,6 +127,10 @@
     "description": "Brave Search API Key",
     "hint": "可添加多个 Key 进行轮询。"
   },
+  "websearch_firecrawl_key": {
+    "description": "Firecrawl API Key",
+    "hint": "可添加多个 Key 进行轮询。"
+  },
   "websearch_baidu_app_builder_key": {
     "description": "百度千帆智能云 APP Builder API Key",
     "hint": "参考:[https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py
index 83e18137c4..ec5e79f492 100644
--- a/tests/test_openai_source.py
+++ b/tests/test_openai_source.py
@@ -1618,3 +1618,109 @@ async def fake_create(**kwargs):
         assert messages[2] == {"role": "user", "content": "hello"}
     finally:
         await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_query_stream_filters_empty_assistant_message(monkeypatch):
+    """Regression for #7721: streaming path must also filter empty assistant messages.
+
+    Previously only ``_query`` sanitized the payload; ``_query_stream`` forwarded
+    the raw history and strict providers (e.g. DeepSeek Reasoner) returned 400 on
+    the next turn after a tool call whose assistant entry had reasoning only.
+    """
+    provider = _make_provider()
+    try:
+        captured_kwargs = {}
+
+        async def fake_stream():
+            yield ChatCompletionChunk.model_validate(
+                {
+                    "id": "chatcmpl-stream",
+                    "object": "chat.completion.chunk",
+                    "created": 0,
+                    "model": "deepseek-reasoner",
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {"role": "assistant", "content": "ok"},
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }
+            )
+
+        async def fake_create(**kwargs):
+            captured_kwargs.update(kwargs)
+            return fake_stream()
+
+        monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)
+
+        payloads = {
+            "model": "deepseek-reasoner",
+            "messages": [
+                {"role": "user", "content": "hello"},
+                {"role": "assistant", "content": ""},  # should be filtered
+                {"role": "user", "content": "world"},
+            ],
+        }
+
+        async for _ in provider._query_stream(payloads=payloads, tools=None):
+            pass
+
+        messages = captured_kwargs["messages"]
+        assert len(messages) == 2
+        assert messages[0] == {"role": "user", "content": "hello"}
+        assert messages[1] == {"role": "user", "content": "world"}
+    finally:
+        await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_query_filters_empty_list_content_assistant_message(monkeypatch):
+    """Empty-list content (``content == []``) must also be filtered, not just ``""`` / ``None``."""
+    provider = _make_provider()
+    try:
+        captured_kwargs = {}
+
+        async def fake_create(**kwargs):
+            captured_kwargs.update(kwargs)
+            return ChatCompletion.model_validate(
+                {
+                    "id": "chatcmpl-test",
+                    "object": "chat.completion",
+                    "created": 0,
+                    "model": "gpt-4o-mini",
+                    "choices": [
+                        {
+                            "index": 0,
+                            "message": {"role": "assistant", "content": "ok"},
+                            "finish_reason": "stop",
+                        }
+                    ],
+                    "usage": {
+                        "prompt_tokens": 1,
+                        "completion_tokens": 1,
+                        "total_tokens": 2,
+                    },
+                }
+            )
+
+        monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)
+
+        payloads = {
+            "model": "gpt-4o-mini",
+            "messages": [
+                {"role": "user", "content": "hi"},
+                {"role": "assistant", "content": []},  # should be filtered
+                {"role": "user", "content": "again"},
+            ],
+        }
+
+        await provider._query(payloads=payloads, tools=None)
+
+        messages = captured_kwargs["messages"]
+        assert len(messages) == 2
+        assert messages[0] == {"role": "user", "content": "hi"}
+        assert messages[1] == {"role": "user", "content": "again"}
+    finally:
+        await provider.terminate()
diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py
index 6ca1a3f2aa..5a5bceae15 100644
--- a/tests/unit/test_astr_main_agent.py
+++ b/tests/unit/test_astr_main_agent.py
@@ -398,6 +398,37 @@ async def test_apply_web_search_tools_uses_builtin_tool_manager(
         assert req.func_tool is not None
         assert req.func_tool.get_tool("web_search_baidu") is builtin_tool
 
+    @pytest.mark.asyncio
+    async def test_apply_web_search_tools_adds_firecrawl_search_and_extract_tools(
+        self, mock_event, mock_context
+    ):
+        """Test Firecrawl web search injects search and extract tools."""
+        module = ama
+        req = ProviderRequest()
+        mock_context.get_config.return_value = {
+            "provider_settings": {
+                "web_search": True,
+                "websearch_provider": "firecrawl",
+            }
+        }
+        search_tool = MagicMock(spec=FunctionTool)
+        search_tool.name = "web_search_firecrawl"
+        extract_tool = MagicMock(spec=FunctionTool)
+        extract_tool.name = "firecrawl_extract_web_page"
+        tool_mgr = MagicMock()
+        tool_mgr.get_builtin_tool.side_effect = [search_tool, extract_tool]
+        mock_context.get_llm_tool_manager.return_value = tool_mgr
+
+        await module._apply_web_search_tools(mock_event, req, mock_context)
+
+        assert tool_mgr.get_builtin_tool.call_args_list == [
+            ((module.FirecrawlWebSearchTool,),),
+            ((module.FirecrawlExtractWebPageTool,),),
+        ]
+        assert req.func_tool is not None
+        assert req.func_tool.get_tool("web_search_firecrawl") is search_tool
+        assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool
+
     def test_proactive_cron_job_tools_uses_builtin_tool_manager(self, mock_context):
         """Test cron tool injection through the builtin tool manager."""
         module = ama
diff --git a/tests/unit/test_func_tool_manager.py b/tests/unit/test_func_tool_manager.py
index 908810cdb3..c87a2de085 100644
--- a/tests/unit/test_func_tool_manager.py
+++ b/tests/unit/test_func_tool_manager.py
@@ -2,6 +2,8 @@
 from astrbot.core.provider.func_tool_manager import FunctionToolManager
 from astrbot.core.tools.computer_tools.shell import ExecuteShellTool
 from astrbot.core.tools.message_tools import SendMessageToUserTool
+from astrbot.core.tools.web_search_tools import FirecrawlExtractWebPageTool
+from astrbot.core.tools.web_search_tools import FirecrawlWebSearchTool
 
 
 def test_get_builtin_tool_by_class_returns_cached_instance():
@@ -38,3 +40,15 @@ def test_computer_tools_are_registered_as_builtin_tools():
 
     assert tool.name == "astrbot_execute_shell"
     assert manager.is_builtin_tool("astrbot_execute_shell") is True
+
+
+def test_firecrawl_tools_are_registered_as_builtin_tools():
+    manager = FunctionToolManager()
+
+    search_tool = manager.get_builtin_tool(FirecrawlWebSearchTool)
+    extract_tool = manager.get_builtin_tool(FirecrawlExtractWebPageTool)
+
+    assert search_tool.name == "web_search_firecrawl"
+    assert extract_tool.name == "firecrawl_extract_web_page"
+    assert manager.is_builtin_tool("web_search_firecrawl") is True
+    assert manager.is_builtin_tool("firecrawl_extract_web_page") is True
diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py
new file mode 100644
index 0000000000..c0ac3cf800
--- /dev/null
+++ b/tests/unit/test_web_search_tools.py
@@ -0,0 +1,380 @@
+import json
+from types import SimpleNamespace
+
+import pytest
+
+from astrbot.core.tools import web_search_tools as tools
+
+
+class _FakeConfig(dict):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.saved = False
+
+    def save_config(self):
+        self.saved = True
+
+
+def test_normalize_legacy_web_search_config_migrates_firecrawl_key():
+    config = _FakeConfig(
+        {"provider_settings": {"websearch_firecrawl_key": "firecrawl-key"}}
+    )
+
+    tools.normalize_legacy_web_search_config(config)
+
+    assert config["provider_settings"]["websearch_firecrawl_key"] == ["firecrawl-key"]
+    assert config.saved is True
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_search_maps_web_results(monkeypatch):
+    async def fake_firecrawl_search(provider_settings, payload):
+        assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"]
+        assert payload == {
+            "query": "AstrBot",
+            "limit": 3,
+            "sources": ["web"],
+            "country": "US",
+        }
+        return [
+            tools.SearchResult(
+                title="AstrBot",
+                url="https://example.com",
+                snippet="Search result",
+            )
+        ]
+
+    monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search)
+    tool = tools.FirecrawlWebSearchTool()
+    context = _context_with_provider_settings(
+        {"websearch_firecrawl_key": ["firecrawl-key"]}
+    )
+
+    result = await tool.call(context, query="AstrBot", limit=3, country="US")
+
+    assert json.loads(result)["results"] == [
+        {
+            "title": "AstrBot",
+            "url": "https://example.com",
+            "snippet": "Search result",
+            "index": json.loads(result)["results"][0]["index"],
+        }
+    ]
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_search_maps_v2_data_list(monkeypatch):
+    session = _FakeFirecrawlSession(
+        _FakeFirecrawlResponse(
+            status=200,
+            json_data={
+                "success": True,
+                "data": [
+                    {
+                        "title": "AstrBot",
+                        "url": "https://example.com",
+                        "description": "Search result",
+                    }
+                ],
+            },
+        )
+    )
+
+    def fake_client_session(*, trust_env):
+        session.trust_env = trust_env
+        return session
+
+    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
+
+    results = await tools._firecrawl_search(
+        {"websearch_firecrawl_key": ["firecrawl-key"]},
+        {"query": "AstrBot", "limit": 5, "sources": ["web"]},
+    )
+
+    assert session.posted == {
+        "url": "https://api.firecrawl.dev/v2/search",
+        "json": {"query": "AstrBot", "limit": 5, "sources": ["web"]},
+        "headers": {
+            "Authorization": "Bearer firecrawl-key",
+            "Content-Type": "application/json",
+        },
+    }
+    assert results == [
+        tools.SearchResult(
+            title="AstrBot", url="https://example.com", snippet="Search result"
+        )
+    ]
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_search_maps_v2_grouped_web_data(monkeypatch):
+    session = _FakeFirecrawlSession(
+        _FakeFirecrawlResponse(
+            status=200,
+            json_data={
+                "success": True,
+                "data": {
+                    "web": [
+                        {
+                            "title": "AstrBot",
+                            "url": "https://example.com",
+                            "description": "Search result",
+                        }
+                    ]
+                },
+            },
+        )
+    )
+
+    def fake_client_session(*, trust_env):
+        session.trust_env = trust_env
+        return session
+
+    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
+
+    results = await tools._firecrawl_search(
+        {"websearch_firecrawl_key": ["firecrawl-key"]},
+        {"query": "AstrBot", "limit": 5, "sources": ["web"]},
+    )
+
+    assert results == [
+        tools.SearchResult(
+            title="AstrBot", url="https://example.com", snippet="Search result"
+        )
+    ]
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_search_payload_omits_tbs_and_uses_default_limit(monkeypatch):
+    async def fake_firecrawl_search(provider_settings, payload):
+        assert payload == {
+            "query": "AstrBot",
+            "limit": 5,
+            "sources": ["web"],
+            "country": "US",
+        }
+        return [
+            tools.SearchResult(
+                title="AstrBot",
+                url="https://example.com",
+                snippet="Search result",
+            )
+        ]
+
+    monkeypatch.setattr(tools, "_firecrawl_search", fake_firecrawl_search)
+    tool = tools.FirecrawlWebSearchTool()
+    context = _context_with_provider_settings(
+        {"websearch_firecrawl_key": ["firecrawl-key"]}
+    )
+
+    result = await tool.call(
+        context,
+        query="AstrBot",
+        tbs="qdr:d",
+        country="US",
+    )
+
+    assert json.loads(result)["results"][0]["url"] == "https://example.com"
+    assert "tbs" not in tool.parameters["properties"]
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_extract_returns_scraped_markdown(monkeypatch):
+    async def fake_firecrawl_scrape(provider_settings, payload):
+        assert provider_settings["websearch_firecrawl_key"] == ["firecrawl-key"]
+        assert payload == {
+            "url": "https://example.com",
+            "formats": ["markdown"],
+            "onlyMainContent": True,
+        }
+        return {"url": "https://example.com", "markdown": "# Example"}
+
+    monkeypatch.setattr(tools, "_firecrawl_scrape", fake_firecrawl_scrape)
+    tool = tools.FirecrawlExtractWebPageTool()
+    context = _context_with_provider_settings(
+        {"websearch_firecrawl_key": ["firecrawl-key"]}
+    )
+
+    result = await tool.call(context, url="https://example.com")
+
+    assert result == "URL: https://example.com\nContent: # Example"
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_search_uses_session_context(monkeypatch):
+    session = _FakeFirecrawlSession(
+        _FakeFirecrawlResponse(
+            status=200,
+            json_data={
+                "success": True,
+                "data": [
+                    {
+                        "title": "AstrBot",
+                        "url": "https://example.com",
+                        "description": "Search result",
+                    }
+                ],
+            },
+        )
+    )
+
+    def fake_client_session(*, trust_env):
+        session.trust_env = trust_env
+        return session
+
+    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
+
+    await tools._firecrawl_search(
+        {"websearch_firecrawl_key": ["firecrawl-key"]},
+        {"query": "AstrBot"},
+    )
+
+    assert session.trust_env is True
+    assert session.entered is True
+    assert session.exited is True
+    assert session.posted == {
+        "url": "https://api.firecrawl.dev/v2/search",
+        "json": {"query": "AstrBot"},
+        "headers": {
+            "Authorization": "Bearer firecrawl-key",
+            "Content-Type": "application/json",
+        },
+    }
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_search_raises_error_for_http_errors(monkeypatch):
+    session = _FakeFirecrawlSession(
+        _FakeFirecrawlResponse(status=401, text_data="Unauthorized")
+    )
+
+    def fake_client_session(*, trust_env):
+        session.trust_env = trust_env
+        return session
+
+    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
+
+    with pytest.raises(
+        Exception,
+        match="Firecrawl web search failed: Unauthorized, status: 401",
+    ):
+        await tools._firecrawl_search(
+            {"websearch_firecrawl_key": ["firecrawl-key"]},
+            {"query": "AstrBot"},
+        )
+
+    assert session.trust_env is True
+    assert session.entered is True
+    assert session.exited is True
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_scrape_uses_request_setup(monkeypatch):
+    session = _FakeFirecrawlSession(
+        _FakeFirecrawlResponse(
+            status=200,
+            json_data={
+                "success": True,
+                "data": {"url": "https://example.com", "markdown": "# Example"},
+            },
+        )
+    )
+
+    def fake_client_session(*, trust_env):
+        session.trust_env = trust_env
+        return session
+
+    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
+
+    result = await tools._firecrawl_scrape(
+        {"websearch_firecrawl_key": ["firecrawl-key"]},
+        {"url": "https://example.com", "formats": ["markdown"]},
+    )
+
+    assert result == {"url": "https://example.com", "markdown": "# Example"}
+    assert session.trust_env is True
+    assert session.entered is True
+    assert session.exited is True
+    assert session.posted == {
+        "url": "https://api.firecrawl.dev/v2/scrape",
+        "json": {"url": "https://example.com", "formats": ["markdown"]},
+        "headers": {
+            "Authorization": "Bearer firecrawl-key",
+            "Content-Type": "application/json",
+        },
+    }
+
+
+@pytest.mark.asyncio
+async def test_firecrawl_scrape_raises_error_for_http_errors(monkeypatch):
+    session = _FakeFirecrawlSession(
+        _FakeFirecrawlResponse(status=401, text_data="Unauthorized")
+    )
+
+    def fake_client_session(*, trust_env):
+        session.trust_env = trust_env
+        return session
+
+    monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
+
+    with pytest.raises(
+        Exception,
+        match="Firecrawl web scraper failed: Unauthorized, status: 401",
+    ):
+        await tools._firecrawl_scrape(
+            {"websearch_firecrawl_key": ["firecrawl-key"]},
+            {"url": "https://example.com", "formats": ["markdown"]},
+        )
+
+    assert session.trust_env is True
+    assert session.entered is True
+    assert session.exited is True
+
+
+class _FakeFirecrawlResponse:
+    def __init__(self, status=200, json_data=None, text_data=""):
+        self.status = status
+        self.json_data = json_data or {}
+        self.text_data = text_data
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return None
+
+    async def json(self):
+        return self.json_data
+
+    async def text(self):
+        return self.text_data
+
+
+class _FakeFirecrawlSession:
+    def __init__(self, response):
+        self.response = response
+        self.trust_env = None
+        self.entered = False
+        self.exited = False
+        self.posted = None
+
+    async def __aenter__(self):
+        self.entered = True
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        self.exited = True
+        return None
+
+    def post(self, url, json, headers):
+        self.posted = {"url": url, "json": json, "headers": headers}
+        return self.response
+
+
+def _context_with_provider_settings(provider_settings):
+    config = {"provider_settings": provider_settings}
+    agent_context = SimpleNamespace(
+        context=SimpleNamespace(get_config=lambda umo: config),
+        event=SimpleNamespace(unified_msg_origin="test:private:session"),
+    )
+    return SimpleNamespace(context=agent_context)