Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/art/rewards/ruler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def ruler(
judge_model: str = "openai/o3",
extra_litellm_params: dict | None = None,
rubric: str = DEFAULT_RUBRIC,
tools: list | None = None,
*,
debug: bool = False,
) -> list[TrajectoryScore]:
Expand All @@ -81,6 +82,8 @@ async def ruler(
extra_litellm_params: Additional parameters to pass to LiteLLM completion.
Can include temperature, max_tokens, etc.
rubric: The grading rubric. The default rubric works well for most tasks.
tools: Optional list of tool definitions available to the agent. When provided,
the judge will see which tools were available when evaluating tool usage.
debug: If True, pretty-print the judge's reasoning to help understand scores.

Returns:
Expand Down Expand Up @@ -137,6 +140,12 @@ async def ruler(
"<context>\n" + json.dumps(common_prefix_messages) + "\n</context>\n\n"
)

# Include available tools so the judge knows which tool calls are valid
if tools:
user_text += (
"<available_tools>\n" + json.dumps(tools) + "\n</available_tools>\n\n"
)

# Serialize each trajectory (minus the common prefix) for the judge.
# If all trajectories are identical, only serialize one full trajectory to save tokens.
serialized_trajectories: List[str] = []
Expand Down Expand Up @@ -292,13 +301,17 @@ async def ruler_score_group(
message_lists.append(traj.messages())
traj.metrics["independent_reward"] = traj.reward

# Extract tools from first trajectory (they should all be the same)
tools = new_trajectories[0].tools if new_trajectories else None

try:
# Call the core ruler function to get scores
scores = await ruler(
message_lists,
judge_model=judge_model,
extra_litellm_params=extra_litellm_params,
rubric=rubric,
tools=tools,
debug=debug,
)
except Exception as e:
Expand Down