Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions src/stirrup/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextvars
import glob as glob_module
import inspect
import json
import logging
import re
import signal
Expand Down Expand Up @@ -177,6 +178,45 @@ def _get_model_speed_stats(messages: list[list[ChatMessage]], model_slug: str) -
}


type JsonSchema = dict[str, object]


def _coerce_embedded_json(value: object, schema: JsonSchema) -> object:
schema_type = schema.get("type")

if isinstance(value, str) and schema_type in {"object", "array"}:
try:
parsed = json.loads(value)
except json.JSONDecodeError:
return value

if (schema_type == "object" and isinstance(parsed, dict)) or (schema_type == "array" and isinstance(parsed, list)):
value = parsed
else:
return value

if isinstance(value, dict):
properties = schema.get("properties")
property_schemas = properties if isinstance(properties, dict) else {}
additional = schema.get("additionalProperties")
return {
key: _coerce_embedded_json(
item,
property_schemas.get(key)
if isinstance(property_schemas.get(key), dict)
else additional if isinstance(additional, dict) else {},
)
for key, item in value.items()
}

if isinstance(value, list):
items = schema.get("items")
item_schema = items if isinstance(items, dict) else {}
return [_coerce_embedded_json(item, item_schema) for item in value]

return value


class SubAgentParams(BaseModel):
"""Parameters for sub-agent tool invocation."""

Expand Down Expand Up @@ -232,6 +272,7 @@ def __init__(
run_sync_in_thread: bool = True,
text_only_tool_responses: bool = True,
block_successive_assistant_messages: bool = True,
normalize_embedded_json_tool_args: bool = False,
# Subagent options
share_parent_exec_env: bool = False,
# Logging
Expand All @@ -256,6 +297,9 @@ def __init__(
block_successive_assistant_messages: If True (default), automatically inject a continue
message when assistant responds without tool calls to
prevent successive assistant messages.
normalize_embedded_json_tool_args: If True, parse stringified JSON nested inside
tool arguments when the schema expects an object
or array.
share_parent_exec_env: When True and used as a subagent, share the parent's code
execution environment instead of creating a new one. This
provides better performance (no file copying) and allows
Expand Down Expand Up @@ -283,6 +327,7 @@ def __init__(
self._run_sync_in_thread = run_sync_in_thread
self._text_only_tool_responses = text_only_tool_responses
self._block_successive_assistant_messages = block_successive_assistant_messages
self._normalize_embedded_json_tool_args = normalize_embedded_json_tool_args
self._share_parent_exec_env = share_parent_exec_env

# Logger (can be passed in or created here)
Expand Down Expand Up @@ -965,7 +1010,12 @@ async def run_tool(self, tool_call: ToolCall, run_metadata: dict[str, list[Any]]
try:
# Normalize empty arguments to valid empty JSON object
args = tool_call.arguments if tool_call.arguments and tool_call.arguments.strip() else "{}"
params = tool.parameters.model_validate_json(args)
if self._normalize_embedded_json_tool_args:
raw_args = json.loads(args)
normalized_args = _coerce_embedded_json(raw_args, tool.parameters.model_json_schema())
params = tool.parameters.model_validate(normalized_args)
else:
params = tool.parameters.model_validate_json(args)

# Set parent depth for sub-agent tools to read
prev_depth = _PARENT_DEPTH.set(self._logger.depth)
Expand All @@ -984,7 +1034,7 @@ async def run_tool(self, tool_call: ToolCall, run_metadata: dict[str, list[Any]]
# Store metadata if present
if result.metadata is not None:
run_metadata[tool_call.name].append(result.metadata)
except ValidationError:
except (json.JSONDecodeError, ValidationError):
LOGGER.debug(
"LLMClient tried to use the tool %s but the tool arguments are not valid: %r",
tool_call.name,
Expand Down
54 changes: 54 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,60 @@ def image_executor(_params: EmptyParams) -> ToolResult:
assert tool_message.content[0].mime_type == "image/png"


async def test_run_tool_normalizes_embedded_json_args_when_enabled() -> None:
    """Stringified nested JSON arguments are accepted only when normalization is switched on."""

    class NestedParams(BaseModel):
        payload: dict[str, int]
        items: list[int]

    def nested_executor(params: NestedParams) -> ToolResult:
        rendered = f"payload={params.payload};items={params.items}"
        return ToolResult(content=rendered)

    tool = Tool[NestedParams, None](
        name="nested_tool",
        description="Parse nested JSON",
        parameters=NestedParams,
        executor=nested_executor,  # ty: ignore[invalid-argument-type]
    )
    # Both fields arrive as JSON *strings* rather than as a real object / array.
    call = ToolCall(
        name="nested_tool",
        arguments='{"payload":"{\\"count\\":1}","items":"[1,2]"}',
        tool_call_id="call_1",
    )

    async def invoke(agent: Agent):
        # Each invocation gets its own session and a fresh metadata mapping.
        async with agent.session() as session:
            return await session.run_tool(call, run_metadata={})

    # Opt-in: the embedded strings are parsed before validation.
    with_normalization = Agent(
        client=MockLLMClient([]),
        name="test-agent-enabled",
        max_turns=1,
        tools=[tool],
        finish_tool=SIMPLE_FINISH_TOOL,
        normalize_embedded_json_tool_args=True,
    )
    accepted = await invoke(with_normalization)

    assert accepted.success is True
    assert accepted.args_was_valid is True
    assert accepted.content == "payload={'count': 1};items=[1, 2]"

    # Default construction (flag not passed): the same call fails validation.
    without_normalization = Agent(
        client=MockLLMClient([]),
        name="test-agent-disabled",
        max_turns=1,
        tools=[tool],
        finish_tool=SIMPLE_FINISH_TOOL,
    )
    rejected = await invoke(without_normalization)

    assert rejected.success is False
    assert rejected.args_was_valid is False
    assert rejected.content == "Tool arguments are not valid"


async def test_agent_invalid_tool_call() -> None:
"""Test agent handles invalid tool calls gracefully."""
# Create mock responses
Expand Down