diff --git a/Dockerfile b/Dockerfile index 7bfb00c7e2..30977605c6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,6 +16,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ curl \ gnupg \ git \ + ripgrep \ && curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \ && apt-get install -y --no-install-recommends nodejs \ && apt-get clean \ diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 9d0b0ffce1..6e3ba40a98 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -4,9 +4,11 @@ import time import traceback import typing as T +import uuid from collections.abc import AsyncIterator from contextlib import suppress from dataclasses import dataclass, field +from pathlib import Path from mcp.types import ( BlobResourceContents, @@ -25,7 +27,7 @@ from astrbot import logger from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart -from astrbot.core.agent.tool import ToolSet +from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.agent.tool_image_cache import tool_image_cache from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.message.components import Json @@ -45,7 +47,7 @@ from ..context.compressor import ContextCompressor from ..context.config import ContextConfig from ..context.manager import ContextManager -from ..context.token_counter import TokenCounter +from ..context.token_counter import EstimateTokenCounter, TokenCounter from ..hooks import BaseAgentRunHooks from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment from ..response import AgentResponseData, AgentStats @@ -97,6 +99,8 @@ class _ToolExecutionInterrupted(Exception): class ToolLoopAgentRunner(BaseAgentRunner[TContext]): + TOOL_RESULT_MAX_ESTIMATED_TOKENS = 27_500 + TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS = 7000 EMPTY_OUTPUT_RETRY_ATTEMPTS = 3 
EMPTY_OUTPUT_RETRY_WAIT_MIN_S = 1 EMPTY_OUTPUT_RETRY_WAIT_MAX_S = 4 @@ -151,6 +155,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): "Otherwise, change strategy, adjust arguments, or explain the limitation " "to the user." ) + TOOL_RESULT_OVERFLOW_NOTICE_TEMPLATE = ( + "Truncated tool output preview shown above. " + "The tool output was too large to include directly and was written to " + "`{overflow_path}`. Use {read_tool_hint} to inspect it. " + "Use a narrower window when reading large files." + ) def _get_persona_custom_error_message(self) -> str | None: """Read persona-level custom error message from event extras when available.""" @@ -206,6 +216,8 @@ async def reset( custom_compressor: ContextCompressor | None = None, tool_schema_mode: str | None = "full", fallback_providers: list[Provider] | None = None, + tool_result_overflow_dir: str | None = None, + read_tool: FunctionTool | None = None, **kwargs: T.Any, ) -> None: self.req = request @@ -217,6 +229,9 @@ async def reset( self.truncate_turns = truncate_turns self.custom_token_counter = custom_token_counter self.custom_compressor = custom_compressor + self.tool_result_overflow_dir = tool_result_overflow_dir + self.read_tool = read_tool + self._tool_result_token_counter = EstimateTokenCounter() # we will do compress when: # 1. before requesting LLM # TODO: 2. 
after LLM output a tool call @@ -298,6 +313,103 @@ async def reset( self.stats = AgentStats() self.stats.start_time = time.time() + def _read_tool_hint(self) -> str: + if self.read_tool is not None: + return f"`{self.read_tool.name}`" + return "the available file-read tool" + + async def _write_tool_result_overflow_file( + self, + *, + tool_call_id: str, + content: str, + ) -> str: + if self.tool_result_overflow_dir is None: + raise ValueError("tool_result_overflow_dir is not configured") + + overflow_dir = Path(self.tool_result_overflow_dir).resolve(strict=False) + safe_tool_call_id = ( + "".join( + ch if ch.isalnum() or ch in {"-", "_", "."} else "_" + for ch in tool_call_id + ).strip("._") + or "tool_call" + ) + file_name = f"{safe_tool_call_id}_{uuid.uuid4().hex[:8]}.txt" + overflow_path = overflow_dir / file_name + + def _run() -> str: + overflow_dir.mkdir(parents=True, exist_ok=True) + overflow_path.write_text(content, encoding="utf-8") + return str(overflow_path) + + return await asyncio.to_thread(_run) + + async def _materialize_large_tool_result( + self, + *, + tool_call_id: str, + content: str, + ) -> str: + if self.tool_result_overflow_dir is None or self.read_tool is None: + return content + + estimated_tokens = self._tool_result_token_counter.count_tokens( + [Message(role="tool", content=content, tool_call_id=tool_call_id)] + ) + if estimated_tokens <= self.TOOL_RESULT_MAX_ESTIMATED_TOKENS: + return content + + preview = self._truncate_tool_result_preview(content, tool_call_id=tool_call_id) + try: + overflow_path = await self._write_tool_result_overflow_file( + tool_call_id=tool_call_id, + content=content, + ) + except Exception as exc: + logger.warning( + "Failed to spill oversized tool result for %s: %s", + tool_call_id, + exc, + exc_info=True, + ) + error_notice = ( + "Tool output exceeded the inline result limit " + f"({estimated_tokens} estimated tokens > " + f"{self.TOOL_RESULT_MAX_ESTIMATED_TOKENS}) and could not be written " + f"to 
`{self.tool_result_overflow_dir}`: {exc}" + ) + if not preview: + return error_notice + return f"{preview}\n\n{error_notice}" + + notice = self.TOOL_RESULT_OVERFLOW_NOTICE_TEMPLATE.format( + overflow_path=overflow_path, + read_tool_hint=self._read_tool_hint(), + ) + if not preview: + return notice + return f"{preview}\n\n{notice}" + + def _truncate_tool_result_preview( + self, + content: str, + *, + tool_call_id: str, + ) -> str: + preview = content + while preview: + estimated_tokens = self._tool_result_token_counter.count_tokens( + [Message(role="tool", content=preview, tool_call_id=tool_call_id)] + ) + if estimated_tokens <= self.TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS: + return preview + next_len = len(preview) // 2 + if next_len <= 0: + break + preview = preview[:next_len] + return preview + async def _iter_llm_responses( self, *, include_model: bool = True ) -> T.AsyncGenerator[LLMResponse, None]: @@ -933,9 +1045,14 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: "The tool has returned a data type that is not supported." 
) if result_parts: + inline_result = "\n\n".join(result_parts) + inline_result = await self._materialize_large_tool_result( + tool_call_id=func_tool_id, + content=inline_result, + ) _append_tool_call_result( func_tool_id, - "\n\n".join(result_parts) + inline_result + self._build_repeated_tool_call_guidance( func_tool_name, tool_call_streak ), diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 18ac1a446a..d668dfeec9 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -19,12 +19,6 @@ from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_main_agent_resources import ( BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT, - EXECUTE_SHELL_TOOL, - FILE_DOWNLOAD_TOOL, - FILE_UPLOAD_TOOL, - LOCAL_EXECUTE_SHELL_TOOL, - LOCAL_PYTHON_TOOL, - PYTHON_TOOL, ) from astrbot.core.cron.events import CronMessageEvent from astrbot.core.message.components import Image @@ -36,6 +30,17 @@ from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools +from astrbot.core.tools.computer_tools import ( + ExecuteShellTool, + FileDownloadTool, + FileEditTool, + FileReadTool, + FileUploadTool, + FileWriteTool, + GrepTool, + LocalPythonTool, + PythonTool, +) from astrbot.core.tools.message_tools import SendMessageToUserTool from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.history_saver import persist_agent_history @@ -177,18 +182,44 @@ async def _run_in_background() -> None: return @classmethod - def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]: + def _get_runtime_computer_tools( + cls, + runtime: str, + tool_mgr, + ) -> dict[str, FunctionTool]: if runtime == "sandbox": + shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool) + python_tool = tool_mgr.get_builtin_tool(PythonTool) + upload_tool = 
tool_mgr.get_builtin_tool(FileUploadTool) + download_tool = tool_mgr.get_builtin_tool(FileDownloadTool) + read_tool = tool_mgr.get_builtin_tool(FileReadTool) + write_tool = tool_mgr.get_builtin_tool(FileWriteTool) + edit_tool = tool_mgr.get_builtin_tool(FileEditTool) + grep_tool = tool_mgr.get_builtin_tool(GrepTool) return { - EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL, - PYTHON_TOOL.name: PYTHON_TOOL, - FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL, - FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL, + shell_tool.name: shell_tool, + python_tool.name: python_tool, + upload_tool.name: upload_tool, + download_tool.name: download_tool, + read_tool.name: read_tool, + write_tool.name: write_tool, + edit_tool.name: edit_tool, + grep_tool.name: grep_tool, } if runtime == "local": + shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool) + python_tool = tool_mgr.get_builtin_tool(LocalPythonTool) + read_tool = tool_mgr.get_builtin_tool(FileReadTool) + write_tool = tool_mgr.get_builtin_tool(FileWriteTool) + edit_tool = tool_mgr.get_builtin_tool(FileEditTool) + grep_tool = tool_mgr.get_builtin_tool(GrepTool) return { - LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL, - LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL, + shell_tool.name: shell_tool, + python_tool.name: python_tool, + read_tool.name: read_tool, + write_tool.name: write_tool, + edit_tool.name: edit_tool, + grep_tool.name: grep_tool, } return {} @@ -203,7 +234,15 @@ def _build_handoff_toolset( cfg = ctx.get_config(umo=event.unified_msg_origin) provider_settings = cfg.get("provider_settings", {}) runtime = str(provider_settings.get("computer_use_runtime", "local")) - runtime_computer_tools = cls._get_runtime_computer_tools(runtime) + tool_mgr = ( + ctx.get_llm_tool_manager() + if hasattr(ctx, "get_llm_tool_manager") + else llm_tools + ) + runtime_computer_tools = cls._get_runtime_computer_tools( + runtime, + tool_mgr, + ) # Keep persona semantics aligned with the main agent: tools=None means # "all tools", including 
runtime computer-use tools. diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index ebb95baa4e..0b74d63c6d 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -9,6 +9,7 @@ import zoneinfo from collections.abc import Coroutine from dataclasses import dataclass, field +from pathlib import Path from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool @@ -20,30 +21,10 @@ from astrbot.core.astr_agent_run_util import AgentRunner from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.astr_main_agent_resources import ( - ANNOTATE_EXECUTION_TOOL, - BROWSER_BATCH_EXEC_TOOL, - BROWSER_EXEC_TOOL, CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT, - CREATE_SKILL_CANDIDATE_TOOL, - CREATE_SKILL_PAYLOAD_TOOL, - EVALUATE_SKILL_CANDIDATE_TOOL, - EXECUTE_SHELL_TOOL, - FILE_DOWNLOAD_TOOL, - FILE_UPLOAD_TOOL, - GET_EXECUTION_HISTORY_TOOL, - GET_SKILL_PAYLOAD_TOOL, - LIST_SKILL_CANDIDATES_TOOL, - LIST_SKILL_RELEASES_TOOL, LIVE_MODE_SYSTEM_PROMPT, LLM_SAFETY_MODE_SYSTEM_PROMPT, - LOCAL_EXECUTE_SHELL_TOOL, - LOCAL_PYTHON_TOOL, - PROMOTE_SKILL_CANDIDATE_TOOL, - PYTHON_TOOL, - ROLLBACK_SKILL_RELEASE_TOOL, - RUN_BROWSER_SKILL_TOOL, SANDBOX_MODE_PROMPT, - SYNC_SKILL_RELEASE_TOOL, TOOL_CALL_PROMPT, TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, ) @@ -56,9 +37,36 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest +from astrbot.core.provider.register import llm_tools from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt from astrbot.core.star.context import Context from astrbot.core.star.star_handler import star_map +from astrbot.core.tools.computer_tools import ( + AnnotateExecutionTool, + BrowserBatchExecTool, + BrowserExecTool, + CreateSkillCandidateTool, + CreateSkillPayloadTool, + EvaluateSkillCandidateTool, + ExecuteShellTool, + FileDownloadTool, + 
FileEditTool, + FileReadTool, + FileUploadTool, + FileWriteTool, + GetExecutionHistoryTool, + GetSkillPayloadTool, + GrepTool, + ListSkillCandidatesTool, + ListSkillReleasesTool, + LocalPythonTool, + PromoteSkillCandidateTool, + PythonTool, + RollbackSkillReleaseTool, + RunBrowserSkillTool, + SyncSkillReleaseTool, + normalize_umo_for_workspace, +) from astrbot.core.tools.cron_tools import FutureTaskTool from astrbot.core.tools.knowledge_base_tools import ( KnowledgeBaseQueryTool, @@ -73,6 +81,10 @@ TavilyWebSearchTool, normalize_legacy_web_search_config, ) +from astrbot.core.utils.astrbot_path import ( + get_astrbot_system_tmp_path, + get_astrbot_workspaces_path, +) from astrbot.core.utils.file_extract import extract_file_moonshotai from astrbot.core.utils.llm_metadata import LLM_METADATAS from astrbot.core.utils.media_utils import ( @@ -290,11 +302,54 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None: req.prompt = f"{prefix}{req.prompt}" -def _apply_local_env_tools(req: ProviderRequest) -> None: +def _get_workspace_path_for_umo(umo: str) -> Path: + normalized_umo = normalize_umo_for_workspace(umo) + return Path(get_astrbot_workspaces_path()) / normalized_umo + + +def _apply_workspace_extra_prompt( + event: AstrMessageEvent, + req: ProviderRequest, +) -> None: + extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / ( + "EXTRA_PROMPT.md" + ) + if not extra_prompt_path.is_file(): + return + + try: + extra_prompt = extra_prompt_path.read_text(encoding="utf-8").strip() + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to read workspace extra prompt for umo=%s from %s: %s", + event.unified_msg_origin, + extra_prompt_path, + exc, + ) + return + + if not extra_prompt: + return + + req.system_prompt = ( + f"{req.system_prompt or ''}\n" + "[Workspace Extra Prompt]\n" + "The following instructions are loaded from the current workspace " + "`EXTRA_PROMPT.md` file.\n" + f"{extra_prompt}\n" + ) + + +def 
_apply_local_env_tools(req: ProviderRequest, plugin_context: Context) -> None: if req.func_tool is None: req.func_tool = ToolSet() - req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL) - req.func_tool.add_tool(LOCAL_PYTHON_TOOL) + tool_mgr = plugin_context.get_llm_tool_manager() + req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(LocalPythonTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool)) req.system_prompt = f"{req.system_prompt or ''}\n{_build_local_mode_prompt()}\n" @@ -765,6 +820,7 @@ async def _decorate_llm_request( if tz is None: tz = plugin_context.get_config().get("timezone") _append_system_reminders(event, req, cfg, tz) + _apply_workspace_extra_prompt(event, req) def _modalities_fix(provider: Provider, req: ProviderRequest) -> None: @@ -981,7 +1037,9 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) - def _apply_sandbox_tools( - config: MainAgentBuildConfig, req: ProviderRequest, session_id: str + config: MainAgentBuildConfig, + req: ProviderRequest, + session_id: str, ) -> None: if req.func_tool is None: req.func_tool = ToolSet() @@ -997,10 +1055,15 @@ def _apply_sandbox_tools( os.environ["SHIPYARD_ENDPOINT"] = ep os.environ["SHIPYARD_ACCESS_TOKEN"] = at - req.func_tool.add_tool(EXECUTE_SHELL_TOOL) - req.func_tool.add_tool(PYTHON_TOOL) - req.func_tool.add_tool(FILE_UPLOAD_TOOL) - req.func_tool.add_tool(FILE_DOWNLOAD_TOOL) + tool_mgr = llm_tools + req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(PythonTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileUploadTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileDownloadTool)) + 
req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool)) if booter == "shipyard_neo": # Neo-specific path rule: filesystem tools operate relative to sandbox # workspace root. Do not prepend "/workspace". @@ -1036,22 +1099,22 @@ def _apply_sandbox_tools( # Browser tools: only register if profile supports browser # (or if capabilities are unknown because sandbox hasn't booted yet) if sandbox_capabilities is None or "browser" in sandbox_capabilities: - req.func_tool.add_tool(BROWSER_EXEC_TOOL) - req.func_tool.add_tool(BROWSER_BATCH_EXEC_TOOL) - req.func_tool.add_tool(RUN_BROWSER_SKILL_TOOL) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserExecTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserBatchExecTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(RunBrowserSkillTool)) # Neo-specific tools (always available for shipyard_neo) - req.func_tool.add_tool(GET_EXECUTION_HISTORY_TOOL) - req.func_tool.add_tool(ANNOTATE_EXECUTION_TOOL) - req.func_tool.add_tool(CREATE_SKILL_PAYLOAD_TOOL) - req.func_tool.add_tool(GET_SKILL_PAYLOAD_TOOL) - req.func_tool.add_tool(CREATE_SKILL_CANDIDATE_TOOL) - req.func_tool.add_tool(LIST_SKILL_CANDIDATES_TOOL) - req.func_tool.add_tool(EVALUATE_SKILL_CANDIDATE_TOOL) - req.func_tool.add_tool(PROMOTE_SKILL_CANDIDATE_TOOL) - req.func_tool.add_tool(LIST_SKILL_RELEASES_TOOL) - req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL) - req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetExecutionHistoryTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(AnnotateExecutionTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillPayloadTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetSkillPayloadTool)) + 
req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillCandidateTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillCandidatesTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(EvaluateSkillCandidateTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(PromoteSkillCandidateTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillReleasesTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(RollbackSkillReleaseTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(SyncSkillReleaseTool)) req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n" @@ -1341,7 +1404,7 @@ async def build_main_agent( if config.computer_use_runtime == "sandbox": _apply_sandbox_tools(config, req, req.session_id) elif config.computer_use_runtime == "local": - _apply_local_env_tools(req) + _apply_local_env_tools(req, plugin_context) agent_runner = AgentRunner() astr_agent_ctx = AstrAgentContext( @@ -1377,6 +1440,15 @@ async def build_main_agent( if config.tool_schema_mode == "full" else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE ) + + if config.computer_use_runtime == "local": + tool_prompt += ( + f"\nCurrent workspace you can use: " + f"`{_get_workspace_path_for_umo(event.unified_msg_origin)}`\n" + "Unless the user explicitly specifies a different directory, " + "perform all file-related operations in this workspace.\n" + ) + req.system_prompt += f"\n{tool_prompt}\n" action_type = event.get_extra("action_type") @@ -1402,6 +1474,14 @@ async def build_main_agent( fallback_providers=_get_fallback_chat_providers( provider, plugin_context, config.provider_settings ), + tool_result_overflow_dir=( + get_astrbot_system_tmp_path() + if req.func_tool and req.func_tool.get_tool("astrbot_file_read_tool") + else None + ), + read_tool=( + req.func_tool.get_tool("astrbot_file_read_tool") if req.func_tool else None + ), ) if apply_reset: diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 
4d1e59c291..821ece702c 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -1,27 +1,5 @@ import base64 -from astrbot.core.computer.tools import ( - AnnotateExecutionTool, - BrowserBatchExecTool, - BrowserExecTool, - CreateSkillCandidateTool, - CreateSkillPayloadTool, - EvaluateSkillCandidateTool, - ExecuteShellTool, - FileDownloadTool, - FileUploadTool, - GetExecutionHistoryTool, - GetSkillPayloadTool, - ListSkillCandidatesTool, - ListSkillReleasesTool, - LocalPythonTool, - PromoteSkillCandidateTool, - PythonTool, - RollbackSkillReleaseTool, - RunBrowserSkillTool, - SyncSkillReleaseTool, -) - LLM_SAFETY_MODE_SYSTEM_PROMPT = """You are running in Safe Mode. Rules: @@ -130,28 +108,6 @@ "{background_task_result}" ) - -EXECUTE_SHELL_TOOL = ExecuteShellTool() -LOCAL_EXECUTE_SHELL_TOOL = ExecuteShellTool(is_local=True) -PYTHON_TOOL = PythonTool() -LOCAL_PYTHON_TOOL = LocalPythonTool() -FILE_UPLOAD_TOOL = FileUploadTool() -FILE_DOWNLOAD_TOOL = FileDownloadTool() -BROWSER_EXEC_TOOL = BrowserExecTool() -BROWSER_BATCH_EXEC_TOOL = BrowserBatchExecTool() -RUN_BROWSER_SKILL_TOOL = RunBrowserSkillTool() -GET_EXECUTION_HISTORY_TOOL = GetExecutionHistoryTool() -ANNOTATE_EXECUTION_TOOL = AnnotateExecutionTool() -CREATE_SKILL_PAYLOAD_TOOL = CreateSkillPayloadTool() -GET_SKILL_PAYLOAD_TOOL = GetSkillPayloadTool() -CREATE_SKILL_CANDIDATE_TOOL = CreateSkillCandidateTool() -LIST_SKILL_CANDIDATES_TOOL = ListSkillCandidatesTool() -EVALUATE_SKILL_CANDIDATE_TOOL = EvaluateSkillCandidateTool() -PROMOTE_SKILL_CANDIDATE_TOOL = PromoteSkillCandidateTool() -LIST_SKILL_RELEASES_TOOL = ListSkillReleasesTool() -ROLLBACK_SKILL_RELEASE_TOOL = RollbackSkillReleaseTool() -SYNC_SKILL_RELEASE_TOOL = SyncSkillReleaseTool() - # we prevent astrbot from connecting to known malicious hosts # these hosts are base64 encoded BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"} diff --git a/astrbot/core/computer/booters/boxlite.py 
b/astrbot/core/computer/booters/boxlite.py index 70064fdd48..aa3ca59761 100644 --- a/astrbot/core/computer/booters/boxlite.py +++ b/astrbot/core/computer/booters/boxlite.py @@ -4,7 +4,7 @@ import aiohttp import boxlite -from shipyard.filesystem import FileSystemComponent as ShipyardFileSystemComponent +from shipyard import FileSystemComponent as ShipyardFileSystemComponent from shipyard.python import PythonComponent as ShipyardPythonComponent from shipyard.shell import ShellComponent as ShipyardShellComponent @@ -12,6 +12,7 @@ from ..olayer import FileSystemComponent, PythonComponent, ShellComponent from .base import ComputerBooter +from .shipyard import ShipyardFileSystemWrapper class MockShipyardSandboxClient: @@ -150,21 +151,24 @@ async def boot(self, session_id: str) -> None: self.mocked = MockShipyardSandboxClient( sb_url=f"http://127.0.0.1:{random_port}" ) - self._fs = ShipyardFileSystemComponent( + self._python = ShipyardPythonComponent( client=self.mocked, # type: ignore ship_id=self.box.id, session_id=session_id, ) - self._python = ShipyardPythonComponent( + self._shell = ShipyardShellComponent( client=self.mocked, # type: ignore ship_id=self.box.id, session_id=session_id, ) - self._shell = ShipyardShellComponent( + self._ship_fs = ShipyardFileSystemComponent( client=self.mocked, # type: ignore ship_id=self.box.id, session_id=session_id, ) + self._fs = ShipyardFileSystemWrapper( + _shipyard_fs=self._ship_fs, _shipyard_shell=self._shell + ) await self.mocked.wait_healthy(self.box.id, session_id) diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index f11bc329fa..44122361d6 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -9,15 +9,18 @@ from dataclasses import dataclass from typing import Any +from python_ripgrep import search + from astrbot.api import logger -from astrbot.core.utils.astrbot_path import ( - get_astrbot_data_path, - get_astrbot_root, - 
get_astrbot_temp_path, +from astrbot.core.computer.file_read_utils import ( + detect_text_encoding, + read_local_text_range_sync, ) +from astrbot.core.utils.astrbot_path import get_astrbot_root from ..olayer import FileSystemComponent, PythonComponent, ShellComponent from .base import ComputerBooter +from .shipyard_search_file_util import _truncate_long_lines _BLOCKED_COMMAND_PATTERNS = [ " rm -rf ", @@ -41,18 +44,6 @@ def _is_safe_command(command: str) -> bool: return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS) -def _ensure_safe_path(path: str) -> str: - abs_path = os.path.abspath(path) - allowed_roots = [ - os.path.abspath(get_astrbot_root()), - os.path.abspath(get_astrbot_data_path()), - os.path.abspath(get_astrbot_temp_path()), - ] - if not any(abs_path.startswith(root) for root in allowed_roots): - raise PermissionError("Path is outside the allowed computer roots.") - return abs_path - - def _decode_bytes_with_fallback( output: bytes | None, *, @@ -110,7 +101,7 @@ def _run() -> dict[str, Any]: run_env = os.environ.copy() if env: run_env.update({str(k): str(v) for k, v in env.items()}) - working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root() + working_dir = os.path.abspath(cwd) if cwd else get_astrbot_root() if background: # `command` is intentionally executed through the current shell so # local computer-use behavior matches existing tool semantics. 
@@ -186,7 +177,7 @@ async def create_file( self, path: str, content: str = "", mode: int = 0o644 ) -> dict[str, Any]: def _run() -> dict[str, Any]: - abs_path = _ensure_safe_path(path) + abs_path = os.path.abspath(path) os.makedirs(os.path.dirname(abs_path), exist_ok=True) with open(abs_path, "w", encoding="utf-8") as f: f.write(content) @@ -195,16 +186,85 @@ def _run() -> dict[str, Any]: return await asyncio.to_thread(_run) - async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]: + async def read_file( + self, + path: str, + encoding: str = "utf-8", + offset: int | None = None, + limit: int | None = None, + ) -> dict[str, Any]: + def _run() -> dict[str, Any]: + abs_path = os.path.abspath(path) + detected_encoding = encoding + if encoding == "utf-8": + with open(abs_path, "rb") as f: + raw_sample = f.read(8192) + detected_encoding = detect_text_encoding(raw_sample) or encoding + return { + "success": True, + "content": read_local_text_range_sync( + abs_path, + encoding=detected_encoding, + offset=offset, + limit=limit, + ), + } + + return await asyncio.to_thread(_run) + + async def search_files( + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + after_context: int | None = None, + before_context: int | None = None, + ) -> dict[str, Any]: def _run() -> dict[str, Any]: - abs_path = _ensure_safe_path(path) - with open(abs_path, "rb") as f: - raw_content = f.read() - content = _decode_bytes_with_fallback( - raw_content, - preferred_encoding=encoding, + results = search( + patterns=[pattern], + paths=[path] if path else None, + globs=[glob] if glob else None, + after_context=after_context, + before_context=before_context, + line_number=True, ) - return {"success": True, "content": content} + return {"success": True, "content": _truncate_long_lines("".join(results))} + + return await asyncio.to_thread(_run) + + async def edit_file( + self, + path: str, + old_string: str, + new_string: str, + replace_all: bool = 
False, + encoding: str = "utf-8", + ) -> dict[str, Any]: + def _run() -> dict[str, Any]: + abs_path = os.path.abspath(path) + with open(abs_path, encoding=encoding) as f: + content = f.read() + occurrences = content.count(old_string) + if occurrences == 0: + return { + "success": False, + "error": "old string not found in file", + "replacements": 0, + } + if replace_all: + updated = content.replace(old_string, new_string) + replacements = occurrences + else: + updated = content.replace(old_string, new_string, 1) + replacements = 1 + with open(abs_path, "w", encoding=encoding) as f: + f.write(updated) + return { + "success": True, + "path": abs_path, + "replacements": replacements, + } return await asyncio.to_thread(_run) @@ -212,7 +272,7 @@ async def write_file( self, path: str, content: str, mode: str = "w", encoding: str = "utf-8" ) -> dict[str, Any]: def _run() -> dict[str, Any]: - abs_path = _ensure_safe_path(path) + abs_path = os.path.abspath(path) os.makedirs(os.path.dirname(abs_path), exist_ok=True) with open(abs_path, mode, encoding=encoding) as f: f.write(content) @@ -222,7 +282,7 @@ def _run() -> dict[str, Any]: async def delete_file(self, path: str) -> dict[str, Any]: def _run() -> dict[str, Any]: - abs_path = _ensure_safe_path(path) + abs_path = os.path.abspath(path) if os.path.isdir(abs_path): shutil.rmtree(abs_path) else: @@ -235,7 +295,7 @@ async def list_dir( self, path: str = ".", show_hidden: bool = False ) -> dict[str, Any]: def _run() -> dict[str, Any]: - abs_path = _ensure_safe_path(path) + abs_path = os.path.abspath(path) entries = os.listdir(abs_path) if not show_hidden: entries = [e for e in entries if not e.startswith(".")] diff --git a/astrbot/core/computer/booters/shipyard.py b/astrbot/core/computer/booters/shipyard.py index 6379d1e48b..bed1b06531 100644 --- a/astrbot/core/computer/booters/shipyard.py +++ b/astrbot/core/computer/booters/shipyard.py @@ -1,9 +1,87 @@ +from __future__ import annotations + +from typing import Any + +from 
shipyard import FileSystemComponent as ShipyardFileSystemComponent from shipyard import ShipyardClient, Spec from astrbot.api import logger from ..olayer import FileSystemComponent, PythonComponent, ShellComponent from .base import ComputerBooter +from .shipyard_search_file_util import search_files_via_shell + + +class ShipyardFileSystemWrapper: + def __init__( + self, _shipyard_fs: ShipyardFileSystemComponent, _shipyard_shell: ShellComponent + ): + self._fs = _shipyard_fs + self._shell = _shipyard_shell + + async def create_file( + self, path: str, content: str = "", mode: int = 420 + ) -> dict[str, Any]: + return await self._fs.create_file(path=path, content=content, mode=mode) + + async def read_file( + self, + path: str, + encoding: str = "utf-8", + offset: int | None = None, + limit: int | None = None, + ) -> dict[str, Any]: + return await self._fs.read_file( + path=path, encoding=encoding, offset=offset, limit=limit + ) + + async def write_file( + self, path: str, content: str, mode: str = "w", encoding: str = "utf-8" + ) -> dict[str, Any]: + return await self._fs.write_file( + path=path, content=content, mode=mode, encoding=encoding + ) + + async def list_dir( + self, path: str = ".", show_hidden: bool = False + ) -> dict[str, Any]: + return await self._fs.list_dir(path=path, show_hidden=show_hidden) + + async def delete_file(self, path: str) -> dict[str, Any]: + return await self._fs.delete_file(path=path) + + async def search_files( + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + after_context: int | None = None, + before_context: int | None = None, + ) -> dict[str, Any]: + return await search_files_via_shell( + self._shell, + pattern=pattern, + path=path, + glob=glob, + after_context=after_context, + before_context=before_context, + ) + + async def edit_file( + self, + path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + encoding: str = "utf-8", + ) -> dict[str, Any]: + return await 
self._fs.edit_file( + path=path, + old_string=old_string, + new_string=new_string, + replace_all=replace_all, + encoding=encoding, + ) class ShipyardBooter(ComputerBooter): @@ -29,13 +107,14 @@ async def boot(self, session_id: str) -> None: ) logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}") self._ship = ship + self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._ship.shell) async def shutdown(self) -> None: logger.info("[Computer] Shipyard booter shutdown.") @property def fs(self) -> FileSystemComponent: - return self._ship.fs + return self._fs @property def python(self) -> PythonComponent: diff --git a/astrbot/core/computer/booters/shipyard_neo.py b/astrbot/core/computer/booters/shipyard_neo.py index 6304696ad2..f3bb7e7b53 100644 --- a/astrbot/core/computer/booters/shipyard_neo.py +++ b/astrbot/core/computer/booters/shipyard_neo.py @@ -13,6 +13,15 @@ ShellComponent, ) from .base import ComputerBooter +from .shipyard_search_file_util import search_files_via_shell + +try: + from shipyard_neo import BayClient + from shipyard_neo.sandbox import Sandbox +except ImportError: + logger.warning( + "shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it." 
+ ) def _maybe_model_dump(value: Any) -> dict[str, Any]: @@ -25,8 +34,20 @@ def _maybe_model_dump(value: Any) -> dict[str, Any]: return {} +def _slice_content_by_lines( + content: str, + *, + offset: int | None = None, + limit: int | None = None, +) -> str: + lines = content.splitlines(keepends=True) + start = 0 if offset is None else offset + selected = lines[start:] if limit is None else lines[start : start + limit] + return "".join(selected) + + class NeoPythonComponent(PythonComponent): - def __init__(self, sandbox: Any) -> None: + def __init__(self, sandbox: Sandbox) -> None: self._sandbox = sandbox async def exec( @@ -67,7 +88,7 @@ async def exec( class NeoShellComponent(ShellComponent): - def __init__(self, sandbox: Any) -> None: + def __init__(self, sandbox: Sandbox) -> None: self._sandbox = sandbox async def exec( @@ -136,8 +157,9 @@ async def exec( class NeoFileSystemComponent(FileSystemComponent): - def __init__(self, sandbox: Any) -> None: + def __init__(self, sandbox: Sandbox, shell: ShellComponent) -> None: self._sandbox = sandbox + self._shell = shell async def create_file( self, @@ -149,10 +171,71 @@ async def create_file( await self._sandbox.filesystem.write_file(path, content) return {"success": True, "path": path} - async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]: + async def read_file( + self, + path: str, + encoding: str = "utf-8", + offset: int | None = None, + limit: int | None = None, + ) -> dict[str, Any]: + _ = encoding + content = await self._sandbox.filesystem.read_file(path) + return { + "success": True, + "path": path, + "content": _slice_content_by_lines( + content, + offset=offset, + limit=limit, + ), + } + + async def search_files( + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + after_context: int | None = None, + before_context: int | None = None, + ) -> dict[str, Any]: + return await search_files_via_shell( + self._shell, + pattern=pattern, + path=path, + glob=glob, + 
after_context=after_context, + before_context=before_context, + ) + + async def edit_file( + self, + path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + encoding: str = "utf-8", + ) -> dict[str, Any]: _ = encoding content = await self._sandbox.filesystem.read_file(path) - return {"success": True, "path": path, "content": content} + occurrences = content.count(old_string) + if occurrences == 0: + return { + "success": False, + "error": "old string not found in file", + "replacements": 0, + } + if replace_all: + updated = content.replace(old_string, new_string) + replacements = occurrences + else: + updated = content.replace(old_string, new_string, 1) + replacements = 1 + await self._sandbox.filesystem.write_file(path, updated) + return { + "success": True, + "path": path, + "replacements": replacements, + } async def write_file( self, @@ -186,7 +269,7 @@ async def list_dir( class NeoBrowserComponent(BrowserComponent): - def __init__(self, sandbox: Any) -> None: + def __init__(self, sandbox: Sandbox) -> None: self._sandbox = sandbox async def exec( @@ -271,8 +354,8 @@ def __init__( self._access_token = access_token self._profile = profile self._ttl = ttl - self._client: Any = None - self._sandbox: Any = None + self._client: BayClient | None = None + self._sandbox: Sandbox | None = None self._bay_manager: Any = None # BayContainerManager when auto-started self._fs: FileSystemComponent | None = None self._python: PythonComponent | None = None @@ -336,8 +419,6 @@ async def boot(self, session_id: str) -> None: "or ensure Bay's credentials.json is accessible for auto-discovery." 
) - from shipyard_neo import BayClient - self._client = BayClient( endpoint_url=self._endpoint_url, access_token=self._access_token, @@ -352,9 +433,9 @@ async def boot(self, session_id: str) -> None: ttl=self._ttl, ) - self._fs = NeoFileSystemComponent(self._sandbox) - self._python = NeoPythonComponent(self._sandbox) self._shell = NeoShellComponent(self._sandbox) + self._fs = NeoFileSystemComponent(self._sandbox, self._shell) + self._python = NeoPythonComponent(self._sandbox) caps = self.capabilities or () self._browser = ( diff --git a/astrbot/core/computer/booters/shipyard_search_file_util.py b/astrbot/core/computer/booters/shipyard_search_file_util.py new file mode 100644 index 0000000000..1227244de3 --- /dev/null +++ b/astrbot/core/computer/booters/shipyard_search_file_util.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import shlex +from typing import Any + +from ..olayer import ShellComponent + +_MAX_SEARCH_LINE_COLUMNS = 1000 + + +def _truncate_long_lines(text: str) -> str: + output_lines: list[str] = [] + for line in text.splitlines(keepends=True): + line_ending = "" + line_body = line + if line.endswith("\r\n"): + line_body = line[:-2] + line_ending = "\r\n" + elif line.endswith("\n") or line.endswith("\r"): + line_body = line[:-1] + line_ending = line[-1] + + if len(line_body) > _MAX_SEARCH_LINE_COLUMNS: + line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS] + + output_lines.append(f"{line_body}{line_ending}") + return "".join(output_lines) + + +def _build_rg_command( + *, + pattern: str, + path: str, + glob: str | None, + after_context: int | None, + before_context: int | None, +) -> list[str]: + command = [ + "rg", + "--color=never", + "-n", + "--max-columns", + str(_MAX_SEARCH_LINE_COLUMNS), + "-e", + pattern, + ] + if glob: + command.extend(["-g", glob]) + if after_context is not None: + command.extend(["-A", str(after_context)]) + if before_context is not None: + command.extend(["-B", str(before_context)]) + command.extend(["--", path]) + 
return command + + +def _build_grep_command( + *, + pattern: str, + path: str, + glob: str | None, + after_context: int | None, + before_context: int | None, +) -> list[str]: + command = ["grep", "-R", "-H", "-n", "-e", pattern] + if glob: + command.append(f"--include={glob}") + if after_context is not None: + command.extend(["-A", str(after_context)]) + if before_context is not None: + command.extend(["-B", str(before_context)]) + command.extend(["--", path]) + return command + + +def _quote_command(command: list[str]) -> str: + return " ".join(shlex.quote(part) for part in command) + + +def build_search_command( + *, + pattern: str, + path: str, + glob: str | None, + after_context: int | None, + before_context: int | None, +) -> str: + rg_command = _quote_command( + _build_rg_command( + pattern=pattern, + path=path, + glob=glob, + after_context=after_context, + before_context=before_context, + ) + ) + grep_command = _quote_command( + _build_grep_command( + pattern=pattern, + path=path, + glob=glob, + after_context=after_context, + before_context=before_context, + ) + ) + return ( + "if command -v rg >/dev/null 2>&1; then " + f"{rg_command}; " + "elif command -v grep >/dev/null 2>&1; then " + f"{grep_command}; " + "else " + "echo 'Neither rg nor grep is available in the sandbox.' 
>&2; " + "exit 127; " + "fi" + ) + + +async def search_files_via_shell( + shell: ShellComponent, + *, + pattern: str, + path: str | None = None, + glob: str | None = None, + after_context: int | None = None, + before_context: int | None = None, + timeout: int = 30, +) -> dict[str, Any]: + command = build_search_command( + pattern=pattern, + path=path or ".", + glob=glob, + after_context=after_context, + before_context=before_context, + ) + result = await shell.exec(command, timeout=timeout) + stdout = _truncate_long_lines(str(result.get("stdout", "") or "")) + stderr = str(result.get("stderr", "") or "") + exit_code = result.get("exit_code") + if exit_code in (0, None): + return {"success": True, "content": stdout} + if exit_code == 1: + return {"success": True, "content": ""} + return { + "success": False, + "content": "", + "error": stderr or f"command exited with code {exit_code}", + "exit_code": exit_code, + } diff --git a/astrbot/core/computer/file_read_utils.py b/astrbot/core/computer/file_read_utils.py new file mode 100644 index 0000000000..0f4d0811cf --- /dev/null +++ b/astrbot/core/computer/file_read_utils.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +import base64 +import hashlib +import io +import json +import zipfile +from asyncio import to_thread +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import mcp + +from astrbot.core.agent.context.token_counter import EstimateTokenCounter +from astrbot.core.agent.message import Message +from astrbot.core.agent.tool import ToolExecResult +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.media_utils import ( + IMAGE_COMPRESS_DEFAULT_MAX_SIZE, + IMAGE_COMPRESS_DEFAULT_OPTIMIZE, + IMAGE_COMPRESS_DEFAULT_QUALITY, + _compress_image_sync, +) + +from .booters.base import ComputerBooter + +_MAX_FILE_READ_BYTES = 128 * 1024 +_MAX_FILE_READ_TOKENS = 25_000 +_MAX_TEXT_FILE_FULL_READ_BYTES = 256 * 1024 +_FILE_SNIFF_BYTES = 
512 +_TOKEN_COUNTER = EstimateTokenCounter() +_TEXT_ENCODINGS = ( + "utf-8-sig", + "utf-8", + "gb18030", + "utf-16", + "utf-16-le", + "utf-16-be", + "utf-32", + "utf-32-le", + "utf-32-be", +) +_UTF_BOMS = ( + b"\xef\xbb\xbf", + b"\xff\xfe", + b"\xfe\xff", + b"\xff\xfe\x00\x00", + b"\x00\x00\xfe\xff", +) +_ZIP_MAGIC_PREFIXES = ( + b"PK\x03\x04", + b"PK\x05\x06", + b"PK\x07\x08", +) +_BINARY_MAGIC_PREFIXES = ( + b"%PDF-", + b"\x1f\x8b", + b"7z\xbc\xaf\x27\x1c", + b"Rar!\x1a\x07", + b"\x7fELF", + b"MZ", +) + + +@dataclass(frozen=True) +class FileProbe: + kind: Literal["text", "image", "binary"] + encoding: str | None + mime_type: str | None + size_bytes: int + + +@dataclass(frozen=True) +class ParsedDocument: + kind: Literal["docx", "pdf"] + file_bytes: bytes + text: str + + +def _build_probe_script(path: str) -> str: + return f""" +import base64 +import json +from pathlib import Path + +path = Path({path!r}) +with path.open("rb") as file_obj: + sample = file_obj.read({_FILE_SNIFF_BYTES}) +print( + json.dumps( + {{ + "size_bytes": path.stat().st_size, + "sample_b64": base64.b64encode(sample).decode("ascii"), + }} + ) +) +""".strip() + + +def _build_text_read_script( + path: str, + *, + encoding: str, + offset: int | None, + limit: int | None, +) -> str: + start_expr = "0" if offset is None else str(offset) + limit_expr = "None" if limit is None else str(limit) + return f""" +import json +from pathlib import Path + +path = Path({path!r}) +start = {start_expr} +limit = {limit_expr} +end = None if limit is None else start + limit +lines = [] +with path.open("r", encoding={encoding!r}, newline="") as file_obj: + for index, line in enumerate(file_obj): + if index < start: + continue + if end is not None and index >= end: + break + lines.append(line) +content = "".join(lines) +print(json.dumps({{"content": content}}, ensure_ascii=False)) +""".strip() + + +def _build_image_read_script(path: str) -> str: + return f""" +import base64 +import json +from pathlib import Path + 
+path = Path({path!r}) +data = path.read_bytes() +print( + json.dumps( + {{ + "size_bytes": len(data), + "base64": base64.b64encode(data).decode("ascii"), + }} + ) +) +""".strip() + + +def _looks_like_text(decoded: str) -> bool: + if not decoded: + return True + + disallowed = 0 + printable = 0 + for char in decoded: + if char in "\n\r\t\f\b": + printable += 1 + continue + if char.isprintable(): + printable += 1 + code = ord(char) + if (0 <= code < 32) or (127 <= code < 160): + disallowed += 1 + + total = max(len(decoded), 1) + return disallowed / total <= 0.02 and printable / total >= 0.85 + + +def detect_text_encoding(sample: bytes) -> str | None: + if not sample: + return "utf-8" + + if b"\x00" in sample and not sample.startswith(_UTF_BOMS): + odd_bytes = sample[1::2] + even_bytes = sample[0::2] + odd_zero_ratio = odd_bytes.count(0) / max(len(odd_bytes), 1) + even_zero_ratio = even_bytes.count(0) / max(len(even_bytes), 1) + if odd_zero_ratio < 0.8 and even_zero_ratio < 0.8: + return None + + for encoding in _TEXT_ENCODINGS: + try: + decoded = sample.decode(encoding) + except UnicodeDecodeError as exc: + # Probe samples can end in the middle of a multibyte sequence. + # When the decode failure only happens at the sample tail, trim a few + # bytes and retry so UTF-8 text is not misclassified as binary. 
+ if exc.start >= len(sample) - 4: + decoded = "" + for trim_bytes in range(1, min(4, len(sample)) + 1): + try: + decoded = sample[:-trim_bytes].decode(encoding) + break + except UnicodeDecodeError: + continue + if not decoded: + continue + else: + continue + if _looks_like_text(decoded): + return encoding + + return None + + +def read_local_text_range_sync( + path: str, + *, + encoding: str, + offset: int | None, + limit: int | None, +) -> str: + lines: list[str] = [] + start = 0 if offset is None else offset + end = None if limit is None else start + limit + with open(path, encoding=encoding, newline="") as file_obj: + for index, line in enumerate(file_obj): + if index < start: + continue + if end is not None and index >= end: + break + lines.append(line) + return "".join(lines) + + +async def read_local_text_range( + path: str, + *, + encoding: str, + offset: int | None, + limit: int | None, +) -> str: + return await to_thread( + read_local_text_range_sync, + path, + encoding=encoding, + offset=offset, + limit=limit, + ) + + +async def _exec_python_json( + booter: ComputerBooter, + script: str, + *, + action: str, +) -> dict: + result = await booter.python.exec(script) + data = result.get("data") if isinstance(result.get("data"), dict) else {} + if not isinstance(data, dict): + raise RuntimeError(f"{action} failed: invalid result format") + output = data.get("output") if isinstance(data.get("output"), dict) else {} + if not isinstance(output, dict): + raise RuntimeError(f"{action} failed: invalid output format") + error_text = str(data.get("error", "") or result.get("error", "") or "").strip() + if error_text: + raise RuntimeError(f"{action} failed: {error_text}") + + text = str(output.get("text", "") or "").strip() + if not text: + raise RuntimeError(f"{action} failed: empty output") + + try: + payload = json.loads(text) + except json.JSONDecodeError as exc: + raise RuntimeError(f"{action} failed: invalid JSON output") from exc + + if not isinstance(payload, 
dict): + raise RuntimeError(f"{action} failed: invalid JSON payload") + return payload + + +async def _probe_local_file(path: str) -> dict[str, str | int]: + def _run() -> dict[str, str | int]: + file_path = Path(path) + with file_path.open("rb") as file_obj: + sample = file_obj.read(_FILE_SNIFF_BYTES) + return { + "size_bytes": file_path.stat().st_size, + "sample_b64": base64.b64encode(sample).decode("ascii"), + } + + return await to_thread(_run) + + +async def _read_local_image_base64(path: str) -> dict[str, str | int]: + def _run() -> dict[str, str | int]: + data = Path(path).read_bytes() + return { + "size_bytes": len(data), + "base64": base64.b64encode(data).decode("ascii"), + } + + return await to_thread(_run) + + +async def _read_local_file_bytes(path: str) -> bytes: + return await to_thread(Path(path).read_bytes) + + +async def _compress_image_bytes_to_base64(data: bytes) -> dict[str, str | int]: + def _run() -> dict[str, str | int]: + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + compressed_path = Path( + _compress_image_sync( + data, + temp_dir, + IMAGE_COMPRESS_DEFAULT_MAX_SIZE, + IMAGE_COMPRESS_DEFAULT_QUALITY, + IMAGE_COMPRESS_DEFAULT_OPTIMIZE, + ) + ) + try: + compressed_bytes = compressed_path.read_bytes() + finally: + compressed_path.unlink(missing_ok=True) + + return { + "size_bytes": len(compressed_bytes), + "base64": base64.b64encode(compressed_bytes).decode("ascii"), + "mime_type": "image/jpeg", + } + + return await to_thread(_run) + + +def _detect_image_mime(sample: bytes) -> str | None: + if sample.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + if sample.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + if sample.startswith((b"GIF87a", b"GIF89a")): + return "image/gif" + if sample.startswith(b"BM"): + return "image/bmp" + if sample.startswith((b"II*\x00", b"MM\x00*")): + return "image/tiff" + if sample.startswith(b"\x00\x00\x01\x00"): + return "image/x-icon" + if len(sample) >= 12 and 
sample[:4] == b"RIFF" and sample[8:12] == b"WEBP": + return "image/webp" + if len(sample) >= 12 and sample[4:12] in (b"ftypavif", b"ftypavis"): + return "image/avif" + return None + + +def _looks_like_known_binary(sample: bytes) -> bool: + return any(sample.startswith(prefix) for prefix in _BINARY_MAGIC_PREFIXES) + + +def _looks_like_pdf(path: str, sample: bytes) -> bool: + return Path(path).suffix.lower() == ".pdf" or sample.startswith(b"%PDF-") + + +def _looks_like_zip_container(sample: bytes) -> bool: + return any(sample.startswith(prefix) for prefix in _ZIP_MAGIC_PREFIXES) + + +def _is_docx_bytes(file_bytes: bytes) -> bool: + try: + with zipfile.ZipFile(io.BytesIO(file_bytes)) as archive: + names = set(archive.namelist()) + except (OSError, zipfile.BadZipFile): + return False + + if "[Content_Types].xml" not in names: + return False + + return any(name.startswith("word/") for name in names) + + +async def _parse_local_docx_text(file_bytes: bytes, file_name: str) -> str: + from astrbot.core.knowledge_base.parsers.markitdown_parser import ( + MarkitdownParser, + ) + + result = await MarkitdownParser().parse(file_bytes, file_name) + return result.text + + +async def _parse_local_pdf_text(file_bytes: bytes, file_name: str) -> str: + from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser + + result = await PDFParser().parse(file_bytes, file_name) + return result.text + + +async def _parse_local_supported_document( + path: str, + sample: bytes, +) -> ParsedDocument | None: + file_name = Path(path).name + if _looks_like_pdf(path, sample): + file_bytes = await _read_local_file_bytes(path) + text = await _parse_local_pdf_text(file_bytes, file_name) + return ParsedDocument(kind="pdf", file_bytes=file_bytes, text=text) + + if Path(path).suffix.lower() == ".docx" or _looks_like_zip_container(sample): + file_bytes = await _read_local_file_bytes(path) + if not _is_docx_bytes(file_bytes): + return None + text = await _parse_local_docx_text(file_bytes, file_name) 
+ return ParsedDocument(kind="docx", file_bytes=file_bytes, text=text) + + return None + + +def _probe_file(sample: bytes, *, size_bytes: int) -> FileProbe: + if image_mime := _detect_image_mime(sample): + return FileProbe( + kind="image", + encoding=None, + mime_type=image_mime, + size_bytes=size_bytes, + ) + + if _looks_like_known_binary(sample): + return FileProbe( + kind="binary", + encoding=None, + mime_type=None, + size_bytes=size_bytes, + ) + + if encoding := detect_text_encoding(sample): + return FileProbe( + kind="text", + encoding=encoding, + mime_type="text/plain", + size_bytes=size_bytes, + ) + + return FileProbe( + kind="binary", + encoding=None, + mime_type=None, + size_bytes=size_bytes, + ) + + +def _validate_text_output(content: str) -> str | None: + content_bytes = len(content.encode("utf-8")) + if content_bytes > _MAX_FILE_READ_BYTES: + return ( + "Error reading file: " + f"output exceeds {_MAX_FILE_READ_BYTES} bytes " + f"({content_bytes} bytes). Use `offset`, `limit` to narrow the read window." + ) + + content_tokens = _TOKEN_COUNTER.count_tokens( + [Message(role="user", content=content)] + ) + if content_tokens > _MAX_FILE_READ_TOKENS: + return ( + "Error reading file: " + f"output exceeds {_MAX_FILE_READ_TOKENS} tokens " + f"({content_tokens} tokens). Use `offset`, `limit` to narrow the read window." + ) + + return None + + +def _text_exceeds_read_thresholds(content: str) -> bool: + return _validate_text_output(content) is not None + + +def _validate_full_text_read_request(probe: FileProbe) -> str | None: + if probe.size_bytes > _MAX_TEXT_FILE_FULL_READ_BYTES: + return ( + "Error reading file: " + f"text file exceeds {_MAX_TEXT_FILE_FULL_READ_BYTES} bytes " + f"({probe.size_bytes} bytes). Use `offset` and `limit` to narrow the read window." 
+ ) + return None + + +def _slice_text_by_lines( + content: str, + *, + offset: int | None, + limit: int | None, +) -> str: + if offset is None and limit is None: + return content + + lines = content.splitlines(keepends=True) + start = 0 if offset is None else offset + end = None if limit is None else start + limit + return "".join(lines[start:end]) + + +async def _store_converted_text_for_workspace( + *, + workspace_dir: str, + original_path: str, + original_bytes: bytes, + content: str, +) -> str: + def _run() -> str: + original_name = Path(original_path).name + digest_suffix = hashlib.md5(original_bytes).hexdigest()[-6:] + target_dir = ( + Path(workspace_dir) / "converted_files" / f"{original_name}_{digest_suffix}" + ) + target_dir.mkdir(parents=True, exist_ok=True) + target_path = target_dir / "text.txt" + target_path.write_text(content, encoding="utf-8") + return str(target_path) + + return await to_thread(_run) + + +def _build_converted_text_notice( + converted_text_path: str, + *, + selection_returned: bool, + selection_too_large: bool = False, +) -> str: + if selection_too_large: + return ( + "Converted text was saved to " + f"`{converted_text_path}`. The requested output is still too large to " + "return directly. Read or grep that file with a narrower window." + ) + + if selection_returned: + return ( + "Full converted text is also available at " + f"`{converted_text_path}`. Read or grep that file with a narrow " + "window for additional reads." + ) + + return ( + "Converted text was saved to " + f"`{converted_text_path}` because the parsed document is too large to " + "return directly. Read or grep that file with a narrow window." + ) + + +async def _read_local_supported_document_result( + *, + path: str, + parsed_document: ParsedDocument, + workspace_dir: str | None, + offset: int | None, + limit: int | None, +) -> ToolExecResult: + content = parsed_document.text + if not content: + return "No content found at the requested line offset." 
+ + if not _text_exceeds_read_thresholds(content): + selected_content = _slice_text_by_lines(content, offset=offset, limit=limit) + if not selected_content: + return "No content found at the requested line offset." + if validation_error := _validate_text_output(selected_content): + return validation_error + return selected_content + + if not workspace_dir: + return ( + "Error reading file: parsed document exceeds the read output limit and " + "no workspace is available for storing converted text." + ) + + converted_text_path = await _store_converted_text_for_workspace( + workspace_dir=workspace_dir, + original_path=path, + original_bytes=parsed_document.file_bytes, + content=content, + ) + + if offset is None and limit is None: + return _build_converted_text_notice( + converted_text_path, + selection_returned=False, + ) + + selected_content = _slice_text_by_lines(content, offset=offset, limit=limit) + if not selected_content: + return ( + "No content found at the requested line offset. " + + _build_converted_text_notice( + converted_text_path, + selection_returned=False, + ) + ) + + notice = _build_converted_text_notice( + converted_text_path, + selection_returned=True, + ) + combined_output = f"{selected_content}\n\n[{notice}]" + if _validate_text_output(combined_output): + if _validate_text_output(selected_content): + return _build_converted_text_notice( + converted_text_path, + selection_returned=False, + selection_too_large=True, + ) + return selected_content + + return combined_output + + +async def read_file_tool_result( + booter: ComputerBooter, + *, + local_mode: bool, + path: str, + offset: int | None, + limit: int | None, + workspace_dir: str | None = None, +) -> ToolExecResult: + if local_mode: + probe_payload = await _probe_local_file(path) + else: + probe_payload = await _exec_python_json( + booter, + _build_probe_script(path), + action="file probe", + ) + sample_b64 = str(probe_payload.get("sample_b64", "") or "") + sample = 
base64.b64decode(sample_b64) if sample_b64 else b"" + size_bytes = int(probe_payload.get("size_bytes", 0) or 0) + probe = _probe_file(sample, size_bytes=size_bytes) + + if local_mode: + try: + parsed_document = await _parse_local_supported_document(path, sample) + except Exception as exc: + return f"Error reading file: failed to parse document: {exc}" + + if parsed_document is not None: + return await _read_local_supported_document_result( + path=path, + parsed_document=parsed_document, + workspace_dir=workspace_dir, + offset=offset, + limit=limit, + ) + + if probe.kind == "binary": + return "Error reading file: binary files are not supported by this tool." + + if probe.kind == "image": + if local_mode: + image_payload = await _read_local_image_base64(path) + else: + image_payload = await _exec_python_json( + booter, + _build_image_read_script(path), + action="image read", + ) + raw_base64_data = str(image_payload.get("base64", "") or "") + if not raw_base64_data: + return "Error reading file: image payload is empty." + raw_bytes = base64.b64decode(raw_base64_data) + compressed_payload = await _compress_image_bytes_to_base64(raw_bytes) + base64_data = str(compressed_payload.get("base64", "") or "") + if not base64_data: + return "Error reading file: compressed image payload is empty." 
+ return mcp.types.CallToolResult( + content=[ + mcp.types.ImageContent( + type="image", + data=base64_data, + mimeType=str( + compressed_payload.get("mime_type", "") or "image/jpeg" + ), + ) + ] + ) + + if offset is None and limit is None: + if validation_error := _validate_full_text_read_request(probe): + return validation_error + + if local_mode: + content = await read_local_text_range( + path, + encoding=probe.encoding or "utf-8", + offset=offset, + limit=limit, + ) + else: + text_payload = await _exec_python_json( + booter, + _build_text_read_script( + path, + encoding=probe.encoding or "utf-8", + offset=offset, + limit=limit, + ), + action="text read", + ) + content = str(text_payload.get("content", "") or "") + + if not content: + return "No content found at the requested line offset." + + if validation_error := _validate_text_output(content): + return validation_error + + return content diff --git a/astrbot/core/computer/olayer/filesystem.py b/astrbot/core/computer/olayer/filesystem.py index 21f36d1110..04df566b1f 100644 --- a/astrbot/core/computer/olayer/filesystem.py +++ b/astrbot/core/computer/olayer/filesystem.py @@ -12,8 +12,36 @@ async def create_file( """Create a file with the specified content""" ... - async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]: - """Read file content""" + async def read_file( + self, + path: str, + encoding: str = "utf-8", + offset: int | None = None, + limit: int | None = None, + ) -> dict[str, Any]: + """Read file content by line window""" + ... + + async def search_files( + self, + pattern: str, + path: str | None = None, + glob: str | None = None, + after_context: int | None = None, + before_context: int | None = None, + ) -> dict[str, Any]: + """Search file contents""" + ... + + async def edit_file( + self, + path: str, + old_string: str, + new_string: str, + replace_all: bool = False, + encoding: str = "utf-8", + ) -> dict[str, Any]: + """Edit file content by string replacement""" ... 
async def write_file( diff --git a/astrbot/core/computer/tools/fs.py b/astrbot/core/computer/tools/fs.py deleted file mode 100644 index f2a698f763..0000000000 --- a/astrbot/core/computer/tools/fs.py +++ /dev/null @@ -1,213 +0,0 @@ -import os -import uuid -from dataclasses import dataclass, field - -from astrbot.api import FunctionTool, logger -from astrbot.api.event import MessageChain -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.message.components import File -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..computer_client import get_booter -from .permissions import check_admin_permission - -# @dataclass -# class CreateFileTool(FunctionTool): -# name: str = "astrbot_create_file" -# description: str = "Create a new file in the sandbox." -# parameters: dict = field( -# default_factory=lambda: { -# "type": "object", -# "properties": { -# "path": { -# "path": "string", -# "description": "The path where the file should be created, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.", -# }, -# "content": { -# "type": "string", -# "description": "The content to write into the file.", -# }, -# }, -# "required": ["path", "content"], -# } -# ) - -# async def call( -# self, context: ContextWrapper[AstrAgentContext], path: str, content: str -# ) -> ToolExecResult: -# sb = await get_booter( -# context.context.context, -# context.context.event.unified_msg_origin, -# ) -# try: -# result = await sb.fs.create_file(path, content) -# return json.dumps(result) -# except Exception as e: -# return f"Error creating file: {str(e)}" - - -# @dataclass -# class ReadFileTool(FunctionTool): -# name: str = "astrbot_read_file" -# description: str = "Read the content of a file in the sandbox." 
-# parameters: dict = field( -# default_factory=lambda: { -# "type": "object", -# "properties": { -# "path": { -# "type": "string", -# "description": "The path of the file to read, relative to the sandbox root. Must not use absolute paths or traverse outside the sandbox.", -# }, -# }, -# "required": ["path"], -# } -# ) - -# async def call(self, context: ContextWrapper[AstrAgentContext], path: str): -# sb = await get_booter( -# context.context.context, -# context.context.event.unified_msg_origin, -# ) -# try: -# result = await sb.fs.read_file(path) -# return result -# except Exception as e: -# return f"Error reading file: {str(e)}" - - -@dataclass -class FileUploadTool(FunctionTool): - name: str = "astrbot_upload_file" - description: str = ( - "Transfer a file FROM the host machine INTO the sandbox so that sandbox " - "code can access it. Use this when the user sends/attaches a file and you " - "need to process it inside the sandbox. The local_path must point to an " - "existing file on the host filesystem." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "local_path": { - "type": "string", - "description": "Absolute path to the file on the host filesystem that will be copied into the sandbox.", - }, - # "remote_path": { - # "type": "string", - # "description": "The filename to use in the sandbox. 
If not provided, file will be saved to the working directory with the same name as the local file.", - # }, - }, - "required": ["local_path"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - local_path: str, - ) -> str | None: - if permission_error := check_admin_permission(context, "File upload/download"): - return permission_error - sb = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - try: - # Check if file exists - if not os.path.exists(local_path): - return f"Error: File does not exist: {local_path}" - - if not os.path.isfile(local_path): - return f"Error: Path is not a file: {local_path}" - - # Use basename if sandbox_filename is not provided - remote_path = os.path.basename(local_path) - - # Upload file to sandbox - result = await sb.upload_file(local_path, remote_path) - logger.debug(f"Upload result: {result}") - success = result.get("success", False) - - if not success: - return f"Error uploading file: {result.get('message', 'Unknown error')}" - - file_path = result.get("file_path", "") - logger.info(f"File {local_path} uploaded to sandbox at {file_path}") - - return f"File uploaded successfully to {file_path}" - except Exception as e: - logger.error(f"Error uploading file {local_path}: {e}") - return f"Error uploading file: {str(e)}" - - -@dataclass -class FileDownloadTool(FunctionTool): - name: str = "astrbot_download_file" - description: str = ( - "Transfer a file FROM the sandbox OUT to the host and optionally send it " - "to the user. Use this ONLY when the user asks to retrieve/export a file " - "that was created or modified inside the sandbox." 
- ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "remote_path": { - "type": "string", - "description": "Path of the file inside the sandbox to copy out to the host.", - }, - "also_send_to_user": { - "type": "boolean", - "description": "Whether to also send the downloaded file to the user via message. Defaults to true.", - }, - }, - "required": ["remote_path"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - remote_path: str, - also_send_to_user: bool = True, - ) -> ToolExecResult: - if permission_error := check_admin_permission(context, "File upload/download"): - return permission_error - sb = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - try: - name = os.path.basename(remote_path) - - local_path = os.path.join( - get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}" - ) - - # Download file from sandbox - await sb.download_file(remote_path, local_path) - logger.info(f"File {remote_path} downloaded from sandbox to {local_path}") - - if also_send_to_user: - try: - name = os.path.basename(local_path) - await context.context.event.send( - MessageChain(chain=[File(name=name, file=local_path)]) - ) - except Exception as e: - logger.error(f"Error sending file message: {e}") - - # remove - # try: - # os.remove(local_path) - # except Exception as e: - # logger.error(f"Error removing temp file {local_path}: {e}") - - return f"File downloaded successfully to {local_path} and sent to user." 
- - return f"File downloaded successfully to {local_path}" - except Exception as e: - logger.error(f"Error downloading file {remote_path}: {e}") - return f"Error downloading file: {str(e)}" diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 058cf61e54..6192e23afd 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -36,6 +36,7 @@ PlatformAdapterType, ) from astrbot.core.subagent_orchestrator import SubAgentOrchestrator +from astrbot.core.utils.astrbot_path import get_astrbot_system_tmp_path from ..exceptions import ProviderNotFoundError from .filter.command import CommandFilter @@ -232,6 +233,13 @@ async def tool_loop_agent( for k, v in kwargs.items() if k not in ["stream", "agent_hooks", "agent_context"] } + if request.func_tool and request.func_tool.get_tool("astrbot_file_read_tool"): + other_kwargs.setdefault( + "tool_result_overflow_dir", get_astrbot_system_tmp_path() + ) + other_kwargs.setdefault( + "read_tool", request.func_tool.get_tool("astrbot_file_read_tool") + ) await agent_runner.reset( provider=prov, diff --git a/astrbot/core/tools/computer_tools/__init__.py b/astrbot/core/tools/computer_tools/__init__.py new file mode 100644 index 0000000000..7e364ffd23 --- /dev/null +++ b/astrbot/core/tools/computer_tools/__init__.py @@ -0,0 +1,55 @@ +from .fs import ( + FileDownloadTool, + FileEditTool, + FileReadTool, + FileUploadTool, + FileWriteTool, + GrepTool, +) +from .python import LocalPythonTool, PythonTool +from .shell import ExecuteShellTool +from .shipyard_neo import ( + AnnotateExecutionTool, + BrowserBatchExecTool, + BrowserExecTool, + CreateSkillCandidateTool, + CreateSkillPayloadTool, + EvaluateSkillCandidateTool, + GetExecutionHistoryTool, + GetSkillPayloadTool, + ListSkillCandidatesTool, + ListSkillReleasesTool, + PromoteSkillCandidateTool, + RollbackSkillReleaseTool, + RunBrowserSkillTool, + SyncSkillReleaseTool, +) +from .util import check_admin_permission, normalize_umo_for_workspace + 
+__all__ = [ + "AnnotateExecutionTool", + "BrowserBatchExecTool", + "BrowserExecTool", + "CreateSkillCandidateTool", + "CreateSkillPayloadTool", + "EvaluateSkillCandidateTool", + "ExecuteShellTool", + "FileDownloadTool", + "FileEditTool", + "FileReadTool", + "FileUploadTool", + "FileWriteTool", + "GetExecutionHistoryTool", + "GetSkillPayloadTool", + "GrepTool", + "ListSkillCandidatesTool", + "ListSkillReleasesTool", + "LocalPythonTool", + "PromoteSkillCandidateTool", + "PythonTool", + "RollbackSkillReleaseTool", + "RunBrowserSkillTool", + "SyncSkillReleaseTool", + "normalize_umo_for_workspace", + "check_admin_permission", +] diff --git a/astrbot/core/tools/computer_tools/fs.py b/astrbot/core/tools/computer_tools/fs.py new file mode 100644 index 0000000000..8d3160ff5d --- /dev/null +++ b/astrbot/core/tools/computer_tools/fs.py @@ -0,0 +1,749 @@ +"""Filesystem tool audit. + +Tool exposure from the main agent: +- Local runtime exposes `astrbot_file_read_tool`, `astrbot_file_write_tool`, + `astrbot_file_edit_tool`, and `astrbot_grep_tool`. +- Sandbox runtime exposes `astrbot_upload_file`, `astrbot_download_file`, + `astrbot_file_read_tool`, `astrbot_file_write_tool`, + `astrbot_file_edit_tool`, and `astrbot_grep_tool`. + +Behavior when `provider_settings.computer_use_require_admin=True`: +- Admin + local: read/write/edit/grep are not path-restricted by this module; + access depends on the local runtime implementation and host OS permissions. + Upload and download tools are defined here, but `LocalBooter` does not + implement them and the main agent does not expose them in local mode. +- Member + local: read/write/edit/grep are restricted to `data/skills`, + `data/workspaces/{normalized_umo}`, `/tmp/.astrbot`, and the AstrBot temp + directory. Upload/download are denied by `check_admin_permission` if invoked. +- Admin + sandbox: read/write/edit/grep are not path-restricted by this + module; + sandbox filesystem boundaries are enforced by the sandbox runtime. Upload and + download are allowed. 
+- Member + sandbox: read/write/edit/grep are also not path-restricted by this + module. Upload/download are denied by `check_admin_permission` if invoked. + +When `computer_use_require_admin=False`, member behavior in this module matches +admin behavior. + +Local path resolution rule: +- In local runtime, relative paths are resolved under + `data/workspaces/{normalized_umo}`. +- In sandbox runtime, relative paths are passed through unchanged. +""" + +import os +import uuid +from dataclasses import dataclass, field +from pathlib import Path + +from astrbot.api import FunctionTool, logger +from astrbot.api.event import MessageChain +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.computer.computer_client import get_booter +from astrbot.core.computer.file_read_utils import read_file_tool_result +from astrbot.core.message.components import File +from astrbot.core.utils.astrbot_path import ( + get_astrbot_skills_path, + get_astrbot_system_tmp_path, + get_astrbot_temp_path, +) + +from ..registry import builtin_tool +from . 
import util as computer_util +from .util import ( + check_admin_permission, + is_local_runtime, + normalize_umo_for_workspace, +) + +_COMPUTER_RUNTIME_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": ("local", "sandbox"), +} +_SANDBOX_RUNTIME_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": "sandbox", +} + + +def _restricted_env_path_labels(umo: str) -> list[str]: + """Labels for the allowed directories in a local(not sandbox) and restricted(not admin) environment""" + normalized_umo = normalize_umo_for_workspace(umo) + return [ + "data/skills", + f"data/workspaces/{normalized_umo}", + get_astrbot_system_tmp_path(), + get_astrbot_temp_path(), + ] + + +def get_astrbot_workspaces_path() -> str: + """Compatibility wrapper for tests and older module-level monkeypatches.""" + return computer_util.get_astrbot_workspaces_path() + + +def _workspace_root(umo: str) -> Path: + """Workspace root that follows both util-level and fs-level getter monkeypatches.""" + normalized_umo = normalize_umo_for_workspace(umo) + return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False) + + +def _read_allowed_roots(umo: str) -> tuple[Path, ...]: + """Non-admin users can only read files within these directories (and their subdirectories)""" + return ( + Path(get_astrbot_skills_path()).resolve(strict=False), + _workspace_root(umo), + Path(get_astrbot_system_tmp_path()).resolve(strict=False), + Path(get_astrbot_temp_path()).resolve(strict=False), + ) + + +def _is_restricted_env(context: ContextWrapper[AstrAgentContext]) -> bool: + if not is_local_runtime(context): + return False + cfg = context.context.context.get_config( + umo=context.context.event.unified_msg_origin + ) + provider_settings = cfg.get("provider_settings", {}) + require_admin = provider_settings.get("computer_use_require_admin", True) + return require_admin and context.context.event.role != "admin" + + +def _resolve_tool_path(path: str, *, local_env: bool, umo: str) -> str: + 
normalized_path = path.strip() + if not normalized_path: + return normalized_path + candidate = Path(normalized_path).expanduser() + if candidate.is_absolute(): + return str(candidate.resolve(strict=False)) + if local_env: + return str((_workspace_root(umo) / candidate).resolve(strict=False)) + return normalized_path + + +def _resolve_user_path(path: str, *, local_env: bool, umo: str) -> Path: + candidate = Path(path).expanduser() + if candidate.is_absolute(): + return candidate.resolve(strict=False) + if local_env: + return (_workspace_root(umo) / candidate).resolve(strict=False) + return (Path.cwd() / candidate).resolve(strict=False) + + +def _is_path_within_allowed_roots(path: str, umo: str) -> bool: + resolved = _resolve_user_path(path, local_env=True, umo=umo) + return any( + resolved == allowed_root or resolved.is_relative_to(allowed_root) + for allowed_root in _read_allowed_roots(umo) + ) + + +def _normalize_rw_path( + path: str, + *, + restricted: bool, + local_env: bool, + umo: str, +) -> str: + normalized_path = _resolve_tool_path(path, local_env=local_env, umo=umo) + if not normalized_path: + raise ValueError("`path` must be a non-empty string.") + if restricted and not _is_path_within_allowed_roots(normalized_path, umo): + allowed = ", ".join(_restricted_env_path_labels(umo)) + raise PermissionError( + "Read access is restricted for this user. " + f"Allowed directories: {allowed}. Blocked path: {normalized_path}." + ) + return normalized_path + + +def _decode_escaped_text(value: str) -> str: + """Decode common escaped control sequences used in tool arguments.""" + return ( + value.replace("\\r\\n", "\n") + .replace("\\n", "\n") + .replace("\\r", "\r") + .replace("\\t", "\t") + ) + + +@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG) +@dataclass +class FileReadTool(FunctionTool): + name: str = "astrbot_file_read_tool" + description: str = "read file content." 
+ parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path of the file to read. If relative, will be in workspace root.", + }, + "offset": { + "type": "integer", + "description": "Optional line offset to start reading from. 0-based index.", + "minimum": 0, + }, + "limit": { + "type": "integer", + "description": "Optional maximum number of lines to read.", + "minimum": 1, + }, + }, + "required": ["path"], + } + ) + + def _validate_read_window( + self, + offset: int | None, + limit: int | None, + ) -> tuple[int | None, int | None]: + if offset is not None and offset < 0: + raise ValueError("`offset` must be greater than or equal to 0.") + if limit is not None and limit < 1: + raise ValueError("`limit` must be greater than or equal to 1.") + return offset, limit + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + path: str, + offset: int | None = None, + limit: int | None = None, + ) -> ToolExecResult: + local_env = is_local_runtime(context) + restricted = _is_restricted_env(context) + try: + normalized_path = ( + _normalize_rw_path( + path, + restricted=restricted, + local_env=local_env, + umo=context.context.event.unified_msg_origin, + ) + if local_env + else path.strip() + ) + if not normalized_path: + raise ValueError("`path` must be a non-empty string.") + offset, limit = self._validate_read_window(offset, limit) + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + return await read_file_tool_result( + sb, + local_mode=local_env, + path=normalized_path, + offset=offset, + limit=limit, + workspace_dir=( + str(_workspace_root(context.context.event.unified_msg_origin)) + if local_env + else None + ), + ) + except PermissionError as exc: + return f"Error: {exc}" + except Exception as exc: + logger.error(f"Error reading file: {exc}") + return f"Error reading file: {exc}" + + 
+@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG) +@dataclass +class FileWriteTool(FunctionTool): + name: str = "astrbot_file_write_tool" + description: str = "Write UTF-8 text content to a file." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path of the file to write. If relative, will be in workspace root.", + }, + "content": { + "type": "string", + "description": "The content to write to the file", + }, + }, + "required": ["path", "content"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + path: str, + content: str, + ) -> ToolExecResult: + local_env = is_local_runtime(context) + restricted = _is_restricted_env(context) + try: + normalized_path = ( + _normalize_rw_path( + path, + restricted=restricted, + local_env=local_env, + umo=context.context.event.unified_msg_origin, + ) + if local_env + else path.strip() + ) + if not normalized_path: + raise ValueError("`path` must be a non-empty string.") + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + result = await sb.fs.write_file( + path=normalized_path, + content=content, + mode="w", + encoding="utf-8", + ) + if not result.get("success", False): + error_detail = str(result.get("error", "") or "").strip() + return ( + "Error writing file: " + f"{error_detail or 'unknown filesystem write error'}" + ) + return f"File written successfully: {normalized_path}" + except PermissionError as exc: + return f"Error: {exc}" + except Exception as exc: + logger.error(f"Error writing file: {exc}") + return f"Error writing file: {exc}" + + +@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG) +@dataclass +class FileEditTool(FunctionTool): + name: str = "astrbot_file_edit_tool" + description: str = "Editing files." 
+ parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path of the file to edit. If relative, will be in workspace root.", + }, + "old": { + "type": "string", + "description": "The exact old text to replace.", + }, + "new": { + "type": "string", + "description": "The replacement text.", + }, + "replace_all": { + "type": "boolean", + "description": "Whether to replace all matches. Defaults to false.", + }, + }, + "required": ["path", "old", "new"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + path: str, + old: str, + new: str, + replace_all: bool = False, + ) -> ToolExecResult: + umo = str(context.context.event.unified_msg_origin) + local_env = is_local_runtime(context) + restricted = _is_restricted_env(context) + try: + normalized_path = ( + _normalize_rw_path( + path, + restricted=restricted, + local_env=local_env, + umo=umo, + ) + if local_env + else path.strip() + ) + if not normalized_path: + raise ValueError("`path` must be a non-empty string.") + normalized_old = _decode_escaped_text(old) + normalized_new = _decode_escaped_text(new) + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + result = await sb.fs.edit_file( + path=normalized_path, + old_string=normalized_old, + new_string=normalized_new, + replace_all=replace_all, + encoding="utf-8", + ) + if not result.get("success", False): + error_detail = str(result.get("error", "") or "").strip() + return ( + "Error editing file: " + f"{error_detail or 'unknown filesystem edit error'}" + ) + replacements = int(result.get("replacements", 0) or 0) + mode_text = "all matches" if replace_all else "first match" + return ( + f"Edited {normalized_path}. " + f"Replaced {replacements} occurrence(s) using {mode_text} mode." 
+ ) + except PermissionError as exc: + return f"Error: {exc}" + except Exception as exc: + logger.error(f"Error editing file: {exc}") + return f"Error editing file: {exc}" + + +@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG) +@dataclass +class GrepTool(FunctionTool): + name: str = "astrbot_grep_tool" + description: str = "Search and read file contents using ripgrep." + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "The expression pattern to search for in file contents.", + }, + "path": { + "type": "string", + "description": "File or directory to search in (rg PATH). If relative, will be in workspace root.", + }, + "glob": { + "type": "string", + "description": "Optional glob filter such as `*.py`, `*.{ts,tsx}`.", + }, + "-A": { + "type": "integer", + "description": "Number of trailing context lines to include after each match.", + "minimum": 0, + }, + "-B": { + "type": "integer", + "description": "Number of leading context lines to include before each match.", + "minimum": 0, + }, + "-C": { + "type": "integer", + "description": "Number of leading and trailing context lines to include around each match.", + "minimum": 0, + }, + "result_limit": { + "type": "integer", + "description": "Maximum number of result groups returned by the tool. 
Defaults to 100.", + "minimum": 1, + }, + }, + "required": ["pattern"], + } + ) + + def _resolve_context_options( + self, + after_context: int | None, + before_context: int | None, + context: int | None, + ) -> tuple[int | None, int | None]: + if context is not None and context < 0: + raise ValueError("`-C` must be greater than or equal to 0.") + if after_context is not None and after_context < 0: + raise ValueError("`-A` must be greater than or equal to 0.") + if before_context is not None and before_context < 0: + raise ValueError("`-B` must be greater than or equal to 0.") + + resolved_after = context if after_context is None else after_context + resolved_before = context if before_context is None else before_context + return resolved_after, resolved_before + + def _split_output_groups(self, output: str, *, has_context: bool) -> list[str]: + if not output.strip(): + return [] + + if not has_context: + return [f"{line}\n" for line in output.splitlines() if line.strip()] + + groups: list[str] = [] + current: list[str] = [] + + for line in output.splitlines(keepends=True): + if line.strip() == "--": + if current: + groups.append("".join(current)) + current = [] + continue + if not line.strip(): + continue + current.append(line) + + if current: + groups.append("".join(current)) + return groups + + def _apply_result_limit( + self, + output: str, + *, + result_limit: int, + has_context: bool, + ) -> str: + if result_limit < 1: + raise ValueError("`result_limit` must be greater than or equal to 1.") + + groups = self._split_output_groups(output, has_context=has_context) + if len(groups) <= result_limit: + return output if output.strip() else "No matches found." 
+ + limited_output = "".join(groups[:result_limit]).rstrip() + return f"{limited_output}\n\n[Truncated to first {result_limit} result groups.]" + + def _normalize_search_paths( + self, + path: str | None, + *, + restricted: bool, + local_env: bool, + umo: str, + ) -> list[str]: + normalized = ( + [_resolve_tool_path(path, local_env=local_env, umo=umo)] if path else [] + ) + if not normalized: + if restricted: + return [str(root) for root in _read_allowed_roots(umo)] + if local_env: + return [str(_workspace_root(umo))] + return ["."] + + if restricted: + disallowed = [ + path + for path in normalized + if not _is_path_within_allowed_roots(path, umo) + ] + if disallowed: + allowed = ", ".join(_restricted_env_path_labels(umo)) + blocked = ", ".join(disallowed) + raise PermissionError( + "Read access is restricted for this user. " + f"Allowed directories: {allowed}. Blocked paths: {blocked}." + ) + + return normalized + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + pattern: str, + path: str | None = None, + glob: str | None = None, + result_limit: int = 100, + **kwargs, + ) -> ToolExecResult: + normalized_pattern = pattern.strip() + if not normalized_pattern: + return "Error: `pattern` must be a non-empty string." 
+ + local_env = is_local_runtime(context) + restricted = _is_restricted_env(context) + try: + search_paths = ( + self._normalize_search_paths( + path, + restricted=restricted, + local_env=local_env, + umo=context.context.event.unified_msg_origin, + ) + if local_env + else ([path.strip()] if path and path.strip() else ["."]) + ) + after_context, before_context = self._resolve_context_options( + kwargs.get("-A"), + kwargs.get("-B"), + kwargs.get("-C"), + ) + has_context = (after_context or 0) > 0 or (before_context or 0) > 0 + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + contents: list[str] = [] + for search_path in search_paths: + result = await sb.fs.search_files( + pattern=normalized_pattern, + path=search_path, + glob=glob, + after_context=after_context, + before_context=before_context, + ) + if not result.get("success", False): + error_detail = str(result.get("error", "") or "").strip() + logger.error("GrepTool search failed: %s", error_detail) + return ( + "Error searching files: " + f"{error_detail or 'unknown filesystem search error'}" + ) + content = str(result.get("content", "") or "") + if content: + contents.append(content) + + return self._apply_result_limit( + "".join(contents), + result_limit=result_limit, + has_context=has_context, + ) + except PermissionError as exc: + return f"Error: {exc}" + except Exception as exc: + logger.error(f"Error searching files: {exc}") + return f"Error searching files: {exc}" + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class FileUploadTool(FunctionTool): + name: str = "astrbot_upload_file" + description: str = ( + "Transfer a file FROM the host machine INTO the sandbox so that sandbox " + "code can access it. Use this when the user sends/attaches a file and you " + "need to process it inside the sandbox. The local_path must point to an " + "existing file on the host filesystem." 
+ ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "local_path": { + "type": "string", + "description": "Absolute path to the file on the host filesystem that will be copied into the sandbox.", + }, + # "remote_path": { + # "type": "string", + # "description": "The filename to use in the sandbox. If not provided, file will be saved to the working directory with the same name as the local file.", + # }, + }, + "required": ["local_path"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + local_path: str, + ) -> str | None: + if permission_error := check_admin_permission(context, "File upload/download"): + return permission_error + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + try: + # Check if file exists + if not os.path.exists(local_path): + return f"Error: File does not exist: {local_path}" + + if not os.path.isfile(local_path): + return f"Error: Path is not a file: {local_path}" + + # Use basename if sandbox_filename is not provided + remote_path = os.path.basename(local_path) + + # Upload file to sandbox + result = await sb.upload_file(local_path, remote_path) + logger.debug(f"Upload result: {result}") + success = result.get("success", False) + + if not success: + return f"Error uploading file: {result.get('message', 'Unknown error')}" + + file_path = result.get("file_path", "") + logger.info(f"File {local_path} uploaded to sandbox at {file_path}") + + return f"File uploaded successfully to {file_path}" + except Exception as e: + logger.error(f"Error uploading file {local_path}: {e}") + return f"Error uploading file: {str(e)}" + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class FileDownloadTool(FunctionTool): + name: str = "astrbot_download_file" + description: str = ( + "Transfer a file FROM the sandbox OUT to the host and optionally send it " + "to the user. 
Use this ONLY when the user asks to retrieve/export a file " + "that was created or modified inside the sandbox." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "remote_path": { + "type": "string", + "description": "Path of the file inside the sandbox to copy out to the host.", + }, + "also_send_to_user": { + "type": "boolean", + "description": "Whether to also send the downloaded file to the user via message. Defaults to true.", + }, + }, + "required": ["remote_path"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + remote_path: str, + also_send_to_user: bool = True, + ) -> ToolExecResult: + if permission_error := check_admin_permission(context, "File upload/download"): + return permission_error + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) + try: + name = os.path.basename(remote_path) + + local_path = os.path.join( + get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}" + ) + + # Download file from sandbox + await sb.download_file(remote_path, local_path) + logger.info(f"File {remote_path} downloaded from sandbox to {local_path}") + + if also_send_to_user: + try: + name = os.path.basename(local_path) + await context.context.event.send( + MessageChain(chain=[File(name=name, file=local_path)]) + ) + except Exception as e: + logger.error(f"Error sending file message: {e}") + + # remove + # try: + # os.remove(local_path) + # except Exception as e: + # logger.error(f"Error removing temp file {local_path}: {e}") + + return f"File downloaded successfully to {local_path} and sent to user." 
+ + return f"File downloaded successfully to {local_path}" + except Exception as e: + logger.error(f"Error downloading file {remote_path}: {e}") + return f"Error downloading file: {str(e)}" diff --git a/astrbot/core/computer/tools/python.py b/astrbot/core/tools/computer_tools/python.py similarity index 90% rename from astrbot/core/computer/tools/python.py rename to astrbot/core/tools/computer_tools/python.py index bf9aaa14e5..e0bb6c9de6 100644 --- a/astrbot/core/computer/tools/python.py +++ b/astrbot/core/tools/computer_tools/python.py @@ -8,10 +8,18 @@ from astrbot.core.agent.tool import ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext, AstrMessageEvent from astrbot.core.computer.computer_client import get_booter, get_local_booter -from astrbot.core.computer.tools.permissions import check_admin_permission from astrbot.core.message.message_event_result import MessageChain +from ..registry import builtin_tool +from .util import check_admin_permission + _OS_NAME = platform.system() +_SANDBOX_PYTHON_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": "sandbox", +} +_LOCAL_PYTHON_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": "local", +} param_schema = { "type": "object", @@ -61,6 +69,7 @@ async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult return resp +@builtin_tool(config=_SANDBOX_PYTHON_TOOL_CONFIG) @dataclass class PythonTool(FunctionTool): name: str = "astrbot_execute_ipython" @@ -83,6 +92,7 @@ async def call( return f"Error executing code: {str(e)}" +@builtin_tool(config=_LOCAL_PYTHON_TOOL_CONFIG) @dataclass class LocalPythonTool(FunctionTool): name: str = "astrbot_execute_python" diff --git a/astrbot/core/computer/tools/shell.py b/astrbot/core/tools/computer_tools/shell.py similarity index 65% rename from astrbot/core/computer/tools/shell.py rename to astrbot/core/tools/computer_tools/shell.py index b5009d30fd..2d3997387c 100644 --- a/astrbot/core/computer/tools/shell.py +++ 
b/astrbot/core/tools/computer_tools/shell.py @@ -5,11 +5,17 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.computer.computer_client import get_booter -from ..computer_client import get_booter, get_local_booter -from .permissions import check_admin_permission +from ..registry import builtin_tool +from .util import check_admin_permission, is_local_runtime, workspace_root +_COMPUTER_RUNTIME_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": ("local", "sandbox"), +} + +@builtin_tool(config=_COMPUTER_RUNTIME_TOOL_CONFIG) @dataclass class ExecuteShellTool(FunctionTool): name: str = "astrbot_execute_shell" @@ -38,8 +44,6 @@ class ExecuteShellTool(FunctionTool): } ) - is_local: bool = False - async def call( self, context: ContextWrapper[AstrAgentContext], @@ -50,15 +54,25 @@ async def call( if permission_error := check_admin_permission(context, "Shell execution"): return permission_error - if self.is_local: - sb = get_local_booter() - else: - sb = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) + sb = await get_booter( + context.context.context, + context.context.event.unified_msg_origin, + ) try: - result = await sb.shell.exec(command, background=background, env=env) + cwd: str | None = None + if is_local_runtime(context): + current_workspace_root = workspace_root( + context.context.event.unified_msg_origin + ) + current_workspace_root.mkdir(parents=True, exist_ok=True) + cwd = str(current_workspace_root) + + result = await sb.shell.exec( + command, + cwd=cwd, + background=background, + env=env, + ) return json.dumps(result) except Exception as e: return f"Error executing command: {str(e)}" diff --git a/astrbot/core/computer/tools/__init__.py b/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py similarity index 77% rename from astrbot/core/computer/tools/__init__.py 
rename to astrbot/core/tools/computer_tools/shipyard_neo/__init__.py index 598abbb6ea..9228c86354 100644 --- a/astrbot/core/computer/tools/__init__.py +++ b/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py @@ -1,5 +1,4 @@ from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool -from .fs import FileDownloadTool, FileUploadTool from .neo_skills import ( AnnotateExecutionTool, CreateSkillCandidateTool, @@ -13,27 +12,20 @@ RollbackSkillReleaseTool, SyncSkillReleaseTool, ) -from .python import LocalPythonTool, PythonTool -from .shell import ExecuteShellTool __all__ = [ - "BrowserExecTool", - "BrowserBatchExecTool", - "RunBrowserSkillTool", - "GetExecutionHistoryTool", "AnnotateExecutionTool", + "BrowserBatchExecTool", + "BrowserExecTool", + "CreateSkillCandidateTool", "CreateSkillPayloadTool", + "EvaluateSkillCandidateTool", + "GetExecutionHistoryTool", "GetSkillPayloadTool", - "CreateSkillCandidateTool", "ListSkillCandidatesTool", - "EvaluateSkillCandidateTool", - "PromoteSkillCandidateTool", "ListSkillReleasesTool", + "PromoteSkillCandidateTool", "RollbackSkillReleaseTool", + "RunBrowserSkillTool", "SyncSkillReleaseTool", - "FileUploadTool", - "PythonTool", - "LocalPythonTool", - "ExecuteShellTool", - "FileDownloadTool", ] diff --git a/astrbot/core/computer/tools/browser.py b/astrbot/core/tools/computer_tools/shipyard_neo/browser.py similarity index 93% rename from astrbot/core/computer/tools/browser.py rename to astrbot/core/tools/computer_tools/shipyard_neo/browser.py index cd8484acb6..b4b7f4fd06 100644 --- a/astrbot/core/computer/tools/browser.py +++ b/astrbot/core/tools/computer_tools/shipyard_neo/browser.py @@ -6,9 +6,14 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.computer.computer_client import get_booter +from astrbot.core.tools.computer_tools.util import check_admin_permission 
+from astrbot.core.tools.registry import builtin_tool -from ..computer_client import get_booter -from .permissions import check_admin_permission +_SHIPYARD_NEO_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "shipyard_neo", +} def _to_json(data: Any) -> str: @@ -29,6 +34,7 @@ async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> A return browser +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class BrowserExecTool(FunctionTool): name: str = "astrbot_execute_browser" @@ -86,6 +92,7 @@ async def call( return f"Error executing browser command: {str(e)}" +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class BrowserBatchExecTool(FunctionTool): name: str = "astrbot_execute_browser_batch" @@ -150,6 +157,7 @@ async def call( return f"Error executing browser batch command: {str(e)}" +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class RunBrowserSkillTool(FunctionTool): name: str = "astrbot_run_browser_skill" diff --git a/astrbot/core/computer/tools/neo_skills.py b/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py similarity index 95% rename from astrbot/core/computer/tools/neo_skills.py rename to astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py index 327f144722..e2c4f59093 100644 --- a/astrbot/core/computer/tools/neo_skills.py +++ b/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py @@ -7,10 +7,15 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import ToolExecResult from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.computer.computer_client import get_booter from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager +from astrbot.core.tools.computer_tools.util import check_admin_permission +from astrbot.core.tools.registry import builtin_tool -from ..computer_client import get_booter -from .permissions import check_admin_permission 
+_SHIPYARD_NEO_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "shipyard_neo", +} def _to_jsonable(model_like: Any) -> Any: @@ -64,6 +69,7 @@ async def _run( return f"{self.error_prefix} {error_action}: {str(e)}" +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class GetExecutionHistoryTool(NeoSkillToolBase): name: str = "astrbot_get_execution_history" @@ -110,6 +116,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class AnnotateExecutionTool(NeoSkillToolBase): name: str = "astrbot_annotate_execution" @@ -147,6 +154,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class CreateSkillPayloadTool(NeoSkillToolBase): name: str = "astrbot_create_skill_payload" @@ -194,6 +202,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class GetSkillPayloadTool(NeoSkillToolBase): name: str = "astrbot_get_skill_payload" @@ -220,6 +229,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class CreateSkillCandidateTool(NeoSkillToolBase): name: str = "astrbot_create_skill_candidate" @@ -273,6 +283,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class ListSkillCandidatesTool(NeoSkillToolBase): name: str = "astrbot_list_skill_candidates" @@ -310,6 +321,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class EvaluateSkillCandidateTool(NeoSkillToolBase): name: str = "astrbot_evaluate_skill_candidate" @@ -350,6 +362,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class PromoteSkillCandidateTool(NeoSkillToolBase): name: str = "astrbot_promote_skill_candidate" @@ -420,6 +433,7 @@ async def call( return f"Error promoting skill candidate: {str(e)}" +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class ListSkillReleasesTool(NeoSkillToolBase): name: str = "astrbot_list_skill_releases" 
@@ -460,6 +474,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class RollbackSkillReleaseTool(NeoSkillToolBase): name: str = "astrbot_rollback_skill_release" @@ -486,6 +501,7 @@ async def call( ) +@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) @dataclass class SyncSkillReleaseTool(NeoSkillToolBase): name: str = "astrbot_sync_skill_release" diff --git a/astrbot/core/computer/tools/permissions.py b/astrbot/core/tools/computer_tools/util.py similarity index 52% rename from astrbot/core/computer/tools/permissions.py rename to astrbot/core/tools/computer_tools/util.py index 489f485f9d..a3930b4c6a 100644 --- a/astrbot/core/computer/tools/permissions.py +++ b/astrbot/core/tools/computer_tools/util.py @@ -1,5 +1,29 @@ +import re +from pathlib import Path + from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.utils.astrbot_path import get_astrbot_workspaces_path + + +def normalize_umo_for_workspace(umo: str) -> str: + normalized = re.sub(r"[^A-Za-z0-9._-]+", "_", umo.strip()) + return normalized or "unknown" + + +def workspace_root(umo: str) -> Path: + """Root directory for relative paths in local runtime""" + normalized_umo = normalize_umo_for_workspace(umo) + return (Path(get_astrbot_workspaces_path()) / normalized_umo).resolve(strict=False) + + +def is_local_runtime(context: ContextWrapper[AstrAgentContext]) -> bool: + cfg = context.context.context.get_config( + umo=context.context.event.unified_msg_origin + ) + provider_settings = cfg.get("provider_settings", {}) + runtime = str(provider_settings.get("computer_use_runtime", "local")) + return runtime == "local" def check_admin_permission( diff --git a/astrbot/core/tools/cron_tools.py b/astrbot/core/tools/cron_tools.py index ccc93172d5..d0b5638168 100644 --- a/astrbot/core/tools/cron_tools.py +++ b/astrbot/core/tools/cron_tools.py @@ -9,6 +9,10 @@ from astrbot.core.astr_agent_context import 
AstrAgentContext from astrbot.core.tools.registry import builtin_tool +_CRON_TOOL_CONFIG = { + "provider_settings.proactive_capability.add_cron_tools": True, +} + def _extract_job_session(job: Any) -> str | None: payload = getattr(job, "payload", None) @@ -24,7 +28,7 @@ def _parse_run_at(run_at: Any) -> datetime | None: return datetime.fromisoformat(str(run_at)) -@builtin_tool +@builtin_tool(config=_CRON_TOOL_CONFIG) @dataclass class FutureTaskTool(FunctionTool[AstrAgentContext]): name: str = "future_task" diff --git a/astrbot/core/tools/knowledge_base_tools.py b/astrbot/core/tools/knowledge_base_tools.py index e27a883d4a..e082fd4253 100644 --- a/astrbot/core/tools/knowledge_base_tools.py +++ b/astrbot/core/tools/knowledge_base_tools.py @@ -9,6 +9,10 @@ from astrbot.core.star.context import Context from astrbot.core.tools.registry import builtin_tool +_KNOWLEDGE_BASE_TOOL_CONFIG = { + "kb_agentic_mode": True, +} + def check_all_kb(kb_list: list[KBHelper | None]) -> bool: """检查是否所有的知识库都为空""" @@ -83,7 +87,7 @@ async def retrieve_knowledge_base( return None -@builtin_tool +@builtin_tool(config=_KNOWLEDGE_BASE_TOOL_CONFIG) @dataclass class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): name: str = "astr_kb_search" diff --git a/astrbot/core/tools/registry.py b/astrbot/core/tools/registry.py index eaca4af144..c3b10d2295 100644 --- a/astrbot/core/tools/registry.py +++ b/astrbot/core/tools/registry.py @@ -1,13 +1,16 @@ from __future__ import annotations +from collections.abc import Callable +from dataclasses import dataclass from importlib import import_module -from typing import TypeVar +from typing import Any, TypeVar from astrbot.core.agent.tool import FunctionTool TFunctionTool = TypeVar("TFunctionTool", bound=type[FunctionTool]) _BUILTIN_TOOL_MODULES = ( + "astrbot.core.tools.computer_tools", "astrbot.core.tools.cron_tools", "astrbot.core.tools.knowledge_base_tools", "astrbot.core.tools.message_tools", @@ -17,6 +20,182 @@ _builtin_tool_classes_by_name: 
dict[str, type[FunctionTool]] = {} _builtin_tool_names_by_class: dict[type[FunctionTool], str] = {} _builtin_tools_loaded = False +_MISSING = object() + + +@dataclass(frozen=True) +class BuiltinToolConfigCondition: + key: str + operator: str + expected: Any = None + message: str | None = None + + def evaluate(self, config: dict[str, Any]) -> dict[str, Any]: + actual = _get_config_value(config, self.key) + + if self.operator == "equals": + matched = actual == self.expected + elif self.operator == "in": + expected_values = tuple(self.expected or ()) + matched = actual in expected_values + elif self.operator == "truthy": + matched = bool(actual) + elif self.operator == "custom": + matched = bool(self.expected) + else: + raise ValueError( + f"Unsupported builtin tool config operator: {self.operator}" + ) + + return { + "key": self.key, + "operator": self.operator, + "expected": _json_safe(self.expected), + "actual": _json_safe(None if actual is _MISSING else actual), + "matched": matched, + "message": self.message, + } + + +@dataclass(frozen=True) +class BuiltinToolConfigRule: + conditions: tuple[BuiltinToolConfigCondition, ...] 
= () + evaluator: Callable[[dict[str, Any]], list[dict[str, Any]]] | None = None + + def evaluate(self, config: dict[str, Any]) -> list[dict[str, Any]]: + if self.evaluator is not None: + return self.evaluator(config) + return [condition.evaluate(config) for condition in self.conditions] + + +def _get_config_value(config: dict[str, Any], key_path: str) -> Any: + current: Any = config + for segment in key_path.split("."): + if not isinstance(current, dict) or segment not in current: + return _MISSING + current = current[segment] + return current + + +def _json_safe(value: Any) -> Any: + if isinstance(value, tuple): + return [_json_safe(item) for item in value] + if isinstance(value, list): + return [_json_safe(item) for item in value] + if isinstance(value, dict): + return {key: _json_safe(val) for key, val in value.items()} + return value + + +def _equals(key: str, expected: Any) -> BuiltinToolConfigCondition: + return BuiltinToolConfigCondition(key=key, operator="equals", expected=expected) + + +def _in(key: str, expected: tuple[Any, ...]) -> BuiltinToolConfigCondition: + return BuiltinToolConfigCondition(key=key, operator="in", expected=expected) + + +def _custom_condition(key: str, *, matched: bool, message: str) -> dict[str, Any]: + return { + "key": key, + "operator": "custom", + "expected": None, + "actual": None, + "matched": matched, + "message": message, + } + + +def _build_rule_from_config_map( + config_map: dict[str, Any], +) -> BuiltinToolConfigRule: + conditions: list[BuiltinToolConfigCondition] = [] + for key, expected in config_map.items(): + if isinstance(expected, tuple): + conditions.append(_in(key, expected)) + else: + conditions.append(_equals(key, expected)) + return BuiltinToolConfigRule(conditions=tuple(conditions)) + + +def _evaluate_send_message_tool(config: dict[str, Any]) -> list[dict[str, Any]]: + platform_configs = config.get("platform", []) + if not isinstance(platform_configs, list): + return [ + _custom_condition( + "platform", + 
matched=False, + message="No enabled platform in this config supports proactive messaging.", + ) + ] + + for platform_cfg in platform_configs: + if not isinstance(platform_cfg, dict): + continue + if platform_cfg.get("enable", False) is False: + continue + + platform_type = str(platform_cfg.get("type", "")).strip() + platform_id = str(platform_cfg.get("id", "")).strip() or platform_type + if not platform_type: + continue + + if platform_type in {"wecom", "weixin_official_account"}: + continue + + if platform_type == "wecom_ai_bot": + webhook = str(platform_cfg.get("msg_push_webhook_url", "")).strip() + if not webhook: + continue + return [ + _custom_condition( + "platform[].type", + matched=True, + message=( + f"Enabled platform `{platform_id}` uses `wecom_ai_bot`, which supports proactive messaging " + "when `platform[].msg_push_webhook_url` is configured." + ), + ), + BuiltinToolConfigCondition( + key="platform[].msg_push_webhook_url", + operator="truthy", + ).evaluate({"platform[]": {"msg_push_webhook_url": webhook}}), + ] + + return [ + _custom_condition( + "platform[].type", + matched=True, + message=( + f"Enabled platform `{platform_id}` (`{platform_type}`) supports proactive messaging." 
+ ), + ) + ] + + return [ + _custom_condition( + "platform", + matched=False, + message="No enabled platform in this config supports proactive messaging.", + ) + ] + + +_BUILTIN_TOOL_CONFIG_RULES: dict[str, BuiltinToolConfigRule] = {} + + +def _register_builtin_tool_config_rule( + tool_names: tuple[str, ...], + rule: BuiltinToolConfigRule, +) -> None: + for tool_name in tool_names: + _BUILTIN_TOOL_CONFIG_RULES[tool_name] = rule + + +_register_builtin_tool_config_rule( + ("send_message_to_user",), + BuiltinToolConfigRule(evaluator=_evaluate_send_message_tool), +) def _resolve_builtin_tool_name(tool_cls: type[FunctionTool]) -> str: @@ -34,18 +213,29 @@ def _resolve_builtin_tool_name(tool_cls: type[FunctionTool]) -> str: ) -def builtin_tool(tool_cls: TFunctionTool) -> TFunctionTool: - tool_name = _resolve_builtin_tool_name(tool_cls) - existing = _builtin_tool_classes_by_name.get(tool_name) - if existing is not None and existing is not tool_cls: - raise ValueError( - f"Builtin tool name conflict detected: {tool_name} is already registered by " - f"{existing.__module__}.{existing.__name__}.", - ) +def builtin_tool( + tool_cls: TFunctionTool | None = None, + *, + config: dict[str, Any] | None = None, +) -> TFunctionTool | Callable[[TFunctionTool], TFunctionTool]: + def _register(cls: TFunctionTool) -> TFunctionTool: + tool_name = _resolve_builtin_tool_name(cls) + existing = _builtin_tool_classes_by_name.get(tool_name) + if existing is not None and existing is not cls: + raise ValueError( + f"Builtin tool name conflict detected: {tool_name} is already registered by " + f"{existing.__module__}.{existing.__name__}.", + ) - _builtin_tool_classes_by_name[tool_name] = tool_cls - _builtin_tool_names_by_class[tool_cls] = tool_name - return tool_cls + _builtin_tool_classes_by_name[tool_name] = cls + _builtin_tool_names_by_class[cls] = tool_name + if config is not None: + _BUILTIN_TOOL_CONFIG_RULES[tool_name] = _build_rule_from_config_map(config) + return cls + + if tool_cls is 
None: + return _register + return _register(tool_cls) def ensure_builtin_tools_loaded() -> None: @@ -74,9 +264,64 @@ def iter_builtin_tool_classes() -> tuple[type[FunctionTool], ...]: return tuple(_builtin_tool_classes_by_name.values()) +def get_builtin_tool_config_rule(name: str) -> BuiltinToolConfigRule | None: + ensure_builtin_tools_loaded() + return _BUILTIN_TOOL_CONFIG_RULES.get(name) + + +def get_builtin_tool_config_statuses( + tool_name: str, + config_entries: list[dict[str, Any]], +) -> list[dict[str, Any]]: + rule = get_builtin_tool_config_rule(tool_name) + if rule is None: + return [] + + statuses: list[dict[str, Any]] = [] + for entry in config_entries: + config = entry.get("config") + if not isinstance(config, dict): + continue + + conditions = rule.evaluate(config) + enabled = bool(conditions) and all( + bool(condition.get("matched")) for condition in conditions + ) + statuses.append( + { + "conf_id": entry.get("conf_id"), + "conf_name": entry.get("conf_name"), + "enabled": enabled, + "matched_conditions": [ + condition for condition in conditions if condition.get("matched") + ], + "failed_conditions": [ + condition + for condition in conditions + if not condition.get("matched") + ], + } + ) + return statuses + + +def get_builtin_tool_config_tags( + tool_name: str, + config_entries: list[dict[str, Any]], +) -> list[dict[str, Any]]: + return [ + status + for status in get_builtin_tool_config_statuses(tool_name, config_entries) + if status["enabled"] + ] + + __all__ = [ "builtin_tool", "ensure_builtin_tools_loaded", + "get_builtin_tool_config_rule", + "get_builtin_tool_config_statuses", + "get_builtin_tool_config_tags", "get_builtin_tool_class", "get_builtin_tool_name", "iter_builtin_tool_classes", diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index 5ca8c3e08e..eacee95114 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -20,6 +20,22 @@ "web_search_bocha", 
"web_search_brave", ] +_TAVILY_WEB_SEARCH_TOOL_CONFIG = { + "provider_settings.web_search": True, + "provider_settings.websearch_provider": "tavily", +} +_BOCHA_WEB_SEARCH_TOOL_CONFIG = { + "provider_settings.web_search": True, + "provider_settings.websearch_provider": "bocha", +} +_BRAVE_WEB_SEARCH_TOOL_CONFIG = { + "provider_settings.web_search": True, + "provider_settings.websearch_provider": "brave", +} +_BAIDU_WEB_SEARCH_TOOL_CONFIG = { + "provider_settings.web_search": True, + "provider_settings.websearch_provider": "baidu_ai_search", +} @std_dataclass @@ -276,7 +292,7 @@ async def _baidu_search( ] -@builtin_tool +@builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG) @pydantic_dataclass class TavilyWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_tavily" @@ -359,7 +375,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) -@builtin_tool +@builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG) @pydantic_dataclass class TavilyExtractWebPageTool(FunctionTool[AstrAgentContext]): name: str = "tavily_extract_web_page" @@ -406,7 +422,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return ret or "Error: Tavily web searcher does not return any results." 
-@builtin_tool +@builtin_tool(config=_BOCHA_WEB_SEARCH_TOOL_CONFIG) @pydantic_dataclass class BochaWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_bocha" @@ -470,7 +486,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) -@builtin_tool +@builtin_tool(config=_BRAVE_WEB_SEARCH_TOOL_CONFIG) @pydantic_dataclass class BraveWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_brave" @@ -528,7 +544,7 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) -@builtin_tool +@builtin_tool(config=_BAIDU_WEB_SEARCH_TOOL_CONFIG) @pydantic_dataclass class BaiduWebSearchTool(FunctionTool[AstrAgentContext]): name: str = "web_search_baidu" diff --git a/astrbot/core/utils/astrbot_path.py b/astrbot/core/utils/astrbot_path.py index 987ce110a5..c7771c1a64 100644 --- a/astrbot/core/utils/astrbot_path.py +++ b/astrbot/core/utils/astrbot_path.py @@ -1,32 +1,33 @@ -"""Astrbot统一路径获取 - -项目路径:固定为源码所在路径 -根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定 -数据目录路径:固定为根目录下的 data 目录 -配置文件路径:固定为数据目录下的 config 目录 -插件目录路径:固定为数据目录下的 plugins 目录 -插件数据目录路径:固定为数据目录下的 plugin_data 目录 -T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录 -WebChat 数据目录路径:固定为数据目录下的 webchat 目录 -临时文件目录路径:固定为数据目录下的 temp 目录 -Skills 目录路径:固定为数据目录下的 skills 目录 -第三方依赖目录路径:固定为数据目录下的 site-packages 目录 +"""Centralized AstrBot path helpers. + +Project path: +- Fixed to the source tree location. + +Root path: +- Defaults to the current working directory. +- Can be overridden with the ``ASTRBOT_ROOT`` environment variable. + +Data subdirectories: +- Most runtime data lives under ``/data``. +- A few tool-runtime files intentionally live under the system temporary + directory as ``.astrbot``. 
""" import os +import tempfile from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime def get_astrbot_path() -> str: - """获取Astrbot项目路径""" + """Return the AstrBot project source path.""" return os.path.realpath( os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../"), ) def get_astrbot_root() -> str: - """获取Astrbot根目录路径""" + """Return the AstrBot root directory.""" if path := os.environ.get("ASTRBOT_ROOT"): return os.path.realpath(path) if is_packaged_desktop_runtime(): @@ -35,55 +36,65 @@ def get_astrbot_root() -> str: def get_astrbot_data_path() -> str: - """获取Astrbot数据目录路径""" + """Return the AstrBot data directory path.""" return os.path.realpath(os.path.join(get_astrbot_root(), "data")) def get_astrbot_config_path() -> str: - """获取Astrbot配置文件路径""" + """Return the AstrBot config directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "config")) def get_astrbot_plugin_path() -> str: - """获取Astrbot插件目录路径""" + """Return the AstrBot plugin directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins")) def get_astrbot_plugin_data_path() -> str: - """获取Astrbot插件数据目录路径""" + """Return the AstrBot plugin data directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data")) def get_astrbot_t2i_templates_path() -> str: - """获取Astrbot T2I 模板目录路径""" + """Return the AstrBot T2I templates directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates")) def get_astrbot_webchat_path() -> str: - """获取Astrbot WebChat 数据目录路径""" + """Return the AstrBot WebChat data directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat")) def get_astrbot_temp_path() -> str: - """获取Astrbot临时文件目录路径""" + """Return the AstrBot temporary data directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp")) def get_astrbot_skills_path() -> str: - """获取Astrbot Skills 目录路径""" + """Return the AstrBot 
skills directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "skills")) +def get_astrbot_workspaces_path() -> str: + """Return the AstrBot workspaces directory path.""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "workspaces")) + + +def get_astrbot_system_tmp_path() -> str: + """Return the shared system temporary directory used by local tools.""" + return os.path.realpath(os.path.join(tempfile.gettempdir(), ".astrbot")) + + def get_astrbot_site_packages_path() -> str: - """获取Astrbot第三方依赖目录路径""" + """Return the AstrBot third-party site-packages directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "site-packages")) def get_astrbot_knowledge_base_path() -> str: - """获取Astrbot知识库根目录路径""" + """Return the AstrBot knowledge base root path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base")) def get_astrbot_backups_path() -> str: - """获取Astrbot备份目录路径""" + """Return the AstrBot backups directory path.""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups")) diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 33b74deffc..7273c51761 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -6,6 +6,7 @@ from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star import star_map +from astrbot.core.tools.registry import get_builtin_tool_config_statuses from .route import Response, Route, RouteContext @@ -434,13 +435,36 @@ async def get_tool_list(self): if tool.name not in existing_names: tools.append(tool) + conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list() + conf_name_map = {conf["id"]: conf["name"] for conf in conf_list} + config_entries = [] + for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items(): + config_entries.append( + { + "conf_id": conf_id, + "conf_name": 
conf_name_map.get(conf_id, conf_id), + "config": conf, + } + ) + tools_dict = [] for tool in tools: readonly = False + builtin_config_statuses = [] + builtin_config_tags = [] if self.tool_mgr.is_builtin_tool(tool.name): origin = "builtin" origin_name = "AstrBot Core" readonly = True + builtin_config_statuses = get_builtin_tool_config_statuses( + tool.name, + config_entries, + ) + builtin_config_tags = [ + status + for status in builtin_config_statuses + if status["enabled"] + ] elif isinstance(tool, MCPTool): origin = "mcp" origin_name = tool.mcp_server_name @@ -462,6 +486,8 @@ async def get_tool_list(self): "origin": origin, "origin_name": origin_name, "readonly": readonly, + "builtin_config_statuses": builtin_config_statuses, + "builtin_config_tags": builtin_config_tags, } tools_dict.append(tool_info) return Response().ok(data=tools_dict).__dict__ diff --git a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue index f0be7bbc83..9615e5ea8c 100644 --- a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue +++ b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue @@ -1,7 +1,7 @@