diff --git a/src/stirrup/core/agent.py b/src/stirrup/core/agent.py index 54fbfa7..52d3c2c 100644 --- a/src/stirrup/core/agent.py +++ b/src/stirrup/core/agent.py @@ -2,6 +2,7 @@ import contextvars import glob as glob_module import inspect +import json import logging import re import signal @@ -177,6 +178,45 @@ def _get_model_speed_stats(messages: list[list[ChatMessage]], model_slug: str) - } +type JsonSchema = dict[str, object] + + +def _coerce_embedded_json(value: object, schema: JsonSchema) -> object: + schema_type = schema.get("type") + + if isinstance(value, str) and schema_type in {"object", "array"}: + try: + parsed = json.loads(value) + except json.JSONDecodeError: + return value + + if (schema_type == "object" and isinstance(parsed, dict)) or (schema_type == "array" and isinstance(parsed, list)): + value = parsed + else: + return value + + if isinstance(value, dict): + properties = schema.get("properties") + property_schemas = properties if isinstance(properties, dict) else {} + additional = schema.get("additionalProperties") + return { + key: _coerce_embedded_json( + item, + property_schemas.get(key) + if isinstance(property_schemas.get(key), dict) + else additional if isinstance(additional, dict) else {}, + ) + for key, item in value.items() + } + + if isinstance(value, list): + items = schema.get("items") + item_schema = items if isinstance(items, dict) else {} + return [_coerce_embedded_json(item, item_schema) for item in value] + + return value + + class SubAgentParams(BaseModel): """Parameters for sub-agent tool invocation.""" @@ -232,6 +272,7 @@ def __init__( run_sync_in_thread: bool = True, text_only_tool_responses: bool = True, block_successive_assistant_messages: bool = True, + normalize_embedded_json_tool_args: bool = False, # Subagent options share_parent_exec_env: bool = False, # Logging @@ -256,6 +297,9 @@ def __init__( block_successive_assistant_messages: If True (default), automatically inject a continue message when assistant responds 
without tool calls to prevent successive assistant messages. + normalize_embedded_json_tool_args: If True, parse stringified JSON nested inside + tool arguments when the schema expects an object + or array. share_parent_exec_env: When True and used as a subagent, share the parent's code execution environment instead of creating a new one. This provides better performance (no file copying) and allows @@ -283,6 +327,7 @@ def __init__( self._run_sync_in_thread = run_sync_in_thread self._text_only_tool_responses = text_only_tool_responses self._block_successive_assistant_messages = block_successive_assistant_messages + self._normalize_embedded_json_tool_args = normalize_embedded_json_tool_args self._share_parent_exec_env = share_parent_exec_env # Logger (can be passed in or created here) @@ -965,7 +1010,12 @@ async def run_tool(self, tool_call: ToolCall, run_metadata: dict[str, list[Any]] try: # Normalize empty arguments to valid empty JSON object args = tool_call.arguments if tool_call.arguments and tool_call.arguments.strip() else "{}" - params = tool.parameters.model_validate_json(args) + if self._normalize_embedded_json_tool_args: + raw_args = json.loads(args) + normalized_args = _coerce_embedded_json(raw_args, tool.parameters.model_json_schema()) + params = tool.parameters.model_validate(normalized_args) + else: + params = tool.parameters.model_validate_json(args) # Set parent depth for sub-agent tools to read prev_depth = _PARENT_DEPTH.set(self._logger.depth) @@ -984,7 +1034,7 @@ async def run_tool(self, tool_call: ToolCall, run_metadata: dict[str, list[Any]] # Store metadata if present if result.metadata is not None: run_metadata[tool_call.name].append(result.metadata) - except ValidationError: + except (json.JSONDecodeError, ValidationError): LOGGER.debug( "LLMClient tried to use the tool %s but the tool arguments are not valid: %r", tool_call.name, diff --git a/tests/test_agent.py b/tests/test_agent.py index d2921fb..e042898 100644 --- a/tests/test_agent.py 
async def test_run_tool_normalizes_embedded_json_args_when_enabled() -> None:
    """Test run_tool decodes stringified nested JSON args only for agents that opt in."""

    class NestedParams(BaseModel):
        payload: dict[str, int]
        items: list[int]

    def nested_executor(params: NestedParams) -> ToolResult:
        return ToolResult(content=f"payload={params.payload};items={params.items}")

    nested_tool = Tool[NestedParams, None](
        name="nested_tool",
        description="Parse nested JSON",
        parameters=NestedParams,
        executor=nested_executor,  # ty: ignore[invalid-argument-type]
    )
    # Both fields arrive as JSON *strings* rather than as an object / array.
    stringified_call = ToolCall(
        name="nested_tool",
        arguments='{"payload":"{\\"count\\":1}","items":"[1,2]"}',
        tool_call_id="call_1",
    )

    async def invoke(agent_name: str, **agent_options: bool):
        """Build a fresh agent with the given options and run the stringified call once."""
        agent = Agent(
            client=MockLLMClient([]),
            name=agent_name,
            max_turns=1,
            tools=[nested_tool],
            finish_tool=SIMPLE_FINISH_TOOL,
            **agent_options,
        )
        async with agent.session() as session:
            return await session.run_tool(stringified_call, run_metadata={})

    # Opt-in: embedded JSON is decoded before validation, so the tool runs.
    enabled_message = await invoke("test-agent-enabled", normalize_embedded_json_tool_args=True)
    assert enabled_message.success is True
    assert enabled_message.args_was_valid is True
    assert enabled_message.content == "payload={'count': 1};items=[1, 2]"

    # Default: the same call fails validation and is reported as invalid arguments.
    disabled_message = await invoke("test-agent-disabled")
    assert disabled_message.success is False
    assert disabled_message.args_was_valid is False
    assert disabled_message.content == "Tool arguments are not valid"