From 16fad4697980773db03cb190dd0afff5e3fc238b Mon Sep 17 00:00:00 2001
From: krisztianfekete
Date: Thu, 2 Apr 2026 17:25:14 +0200
Subject: [PATCH] improve MCP server

---
 src/agentevals/mcp_server.py | 396 ++++++++++++++++++++++++++++-------
 tests/test_mcp_server.py     | 343 ++++++++++++++++++++++++++++++
 2 files changed, 659 insertions(+), 80 deletions(-)
 create mode 100644 tests/test_mcp_server.py

diff --git a/src/agentevals/mcp_server.py b/src/agentevals/mcp_server.py
index 392a727..21527a5 100644
--- a/src/agentevals/mcp_server.py
+++ b/src/agentevals/mcp_server.py
@@ -6,6 +6,7 @@
 import httpx
 from mcp.server import FastMCP
+from pydantic import BaseModel, Field

 from agentevals.config import EvalRunConfig
 from agentevals.runner import run_evaluation

@@ -13,6 +14,119 @@ _DEFAULT_SERVER_URL = "http://localhost:8001"

# ---------------------------------------------------------------------------
# MCP tool response models
# ---------------------------------------------------------------------------


class MetricInfoResponse(BaseModel):
    name: str
    category: str
    requires_eval_set: bool
    requires_llm: bool
    requires_gcp: bool
    requires_rubrics: bool
    description: str
    working: bool


class MetricScoreResponse(BaseModel):
    metric: str
    score: float | None = None
    status: str
    error: str | None = None


class TraceEvalResponse(BaseModel):
    trace_id: str
    num_invocations: int
    metrics: list[MetricScoreResponse]
    warnings: list[str] | None = None


class EvaluateTracesResponse(BaseModel):
    passed: bool
    traces: list[TraceEvalResponse]
    errors: list[str] | None = None


class SessionSummaryResponse(BaseModel):
    session_id: str
    is_complete: bool
    span_count: int
    started_at: str


class ToolCallResponse(BaseModel):
    tool: str
    args: dict[str, Any] = Field(default_factory=dict)


class InvocationSummaryResponse(BaseModel):
    user: str
    response: str
    tool_calls: list[ToolCallResponse]


class SummarizeSessionResponse(BaseModel):
    session_id: str
    num_spans: int
    num_invocations: int = 0
    invocations: list[InvocationSummaryResponse]


class SessionEvalResultResponse(BaseModel):
    session_id: str
    trace_id: str | None = None
    num_invocations: int | None = None
    metric_results: list[dict[str, Any]] | None = None
    error: str | None = None


class EvaluateSessionsResponse(BaseModel):
    golden_session_id: str
    eval_set_id: str
    results: list[SessionEvalResultResponse]


# ---------------------------------------------------------------------------
# Result transformation
# ---------------------------------------------------------------------------


def summarize_run_result(result) -> EvaluateTracesResponse:
    """Transform a RunResult into the MCP tool response shape."""
    traces = []
    for tr in result.trace_results:
        metrics = [
            MetricScoreResponse(
                metric=mr.metric_name,
                score=mr.score,
                status=mr.eval_status,
                error=mr.error if mr.error else None,
            )
            for mr in tr.metric_results
        ]
        traces.append(
            TraceEvalResponse(
                trace_id=tr.trace_id,
                num_invocations=tr.num_invocations,
                metrics=metrics,
                warnings=tr.conversion_warnings if tr.conversion_warnings else None,
            )
        )
    return EvaluateTracesResponse(
        passed=all(m.status != "FAILED" for t in traces for m in t.metrics),
        traces=traces,
        errors=result.errors if result.errors else None,
    )


# ---------------------------------------------------------------------------
# Server factory
#
--------------------------------------------------------------------------- + + def create_server(server_url: str | None = None, **fastmcp_kwargs: Any) -> FastMCP: """Build the FastMCP server. Extra keyword arguments are passed to :class:`FastMCP` (e.g. ``host``, ``port``).""" mcp = FastMCP("agentevals", **fastmcp_kwargs) @@ -49,35 +163,51 @@ async def _post(path: str, body: dict) -> Any: except httpx.HTTPStatusError as exc: raise RuntimeError(f"Server error {exc.response.status_code}: {exc.response.text}") from exc - def _summarize_run_result(result) -> dict[str, Any]: - traces = [] - for tr in result.trace_results: - traces.append( - { - "trace_id": tr.trace_id, - "num_invocations": tr.num_invocations, - "metrics": [ - { - "metric": mr.metric_name, - "score": mr.score, - "status": mr.eval_status, - **({"error": mr.error} if mr.error else {}), - } - for mr in tr.metric_results - ], - **({"warnings": tr.conversion_warnings} if tr.conversion_warnings else {}), - } - ) - return { - "passed": all(mr["status"] != "FAILED" for tr in traces for mr in tr["metrics"]), - "traces": traces, - **({"errors": result.errors} if result.errors else {}), - } - @mcp.tool() - async def list_metrics() -> list[dict[str, Any]]: - """List all available evaluation metrics with their descriptions and requirements.""" - return await _get("/api/metrics") + async def list_metrics() -> list[MetricInfoResponse]: + """List all available evaluation metrics with their descriptions and requirements. + + Call this first to discover which metrics you can pass to evaluate_traces + or evaluate_sessions. Each metric has requirement flags that indicate what + it needs to produce results (an eval set for comparison, an LLM judge, GCP + credentials, or rubric configuration). + + Returns: + A list of metric objects, each containing: + - name: metric identifier to pass to evaluation tools + - category: grouping such as "trajectory", "response", "safety", "quality" + - requires_eval_set: whether a golden eval set file is needed + - requires_llm: whether an LLM judge model is needed + - requires_gcp: whether GCP/Vertex AI credentials are needed + - requires_rubrics: whether rubric configuration is needed + - description: what the metric measures + - working: whether the metric is currently functional + + Common metrics: + "tool_trajectory_avg_score": compares actual tool call sequences + against expected trajectory from an eval set. + "response_match_score": ROUGE-1 text similarity against expected + response (requires eval set). + "hallucinations_v1": detects hallucinated information (requires + LLM judge). + "final_response_match_v2": LLM-based response comparison (requires + eval set and LLM judge). + "safety_v1": safety assessment via Vertex AI (requires GCP). + """ + data = await _get("/api/metrics") + return [ + MetricInfoResponse( + name=m["name"], + category=m["category"], + requires_eval_set=m["requiresEvalSet"], + requires_llm=m["requiresLLM"], + requires_gcp=m["requiresGCP"], + requires_rubrics=m["requiresRubrics"], + description=m["description"], + working=m["working"], + ) + for m in data + ] @mcp.tool() async def evaluate_traces( @@ -88,20 +218,54 @@ async def evaluate_traces( judge_model: str | None = None, threshold: float | None = None, eval_config_file: str | None = None, - ) -> dict[str, Any]: - """Evaluate one or more local agent trace files. + ) -> EvaluateTracesResponse: + """Evaluate one or more local agent trace files against selected metrics. + + This is the primary offline evaluation tool. 
It loads trace files from disk, + converts them to the internal invocation format, and runs each requested + metric. Does not require the agentevals server to be running. - Does not require the agentevals server to be running. Returns a flat summary - with a top-level 'passed' boolean and per-trace metric scores. + Typical workflow: + 1. Call list_metrics to discover available metrics and their requirements. + 2. Call evaluate_traces with trace file paths and chosen metrics. + 3. Check the top-level "passed" field for a quick pass/fail summary. + 4. Inspect per-trace metric scores and errors for details. Args: - trace_files: Absolute paths to Jaeger JSON or OTLP JSON/JSONL trace files. - metrics: Metric names to evaluate. Use list_metrics to see available options. - trace_format: "jaeger-json" or "otlp-json". - eval_set_file: Path to a golden eval set JSON for comparison metrics. - judge_model: LLM model for judge-based metrics (e.g. "gemini-2.5-flash"). - threshold: Score threshold for PASS/FAIL classification (0.0–1.0). - eval_config_file: Path to an eval config YAML file with custom evaluators. + trace_files: Absolute paths to trace files on disk. Supports Jaeger + JSON (.json) and OTLP JSON/JSONL (.jsonl) formats. Each file may + contain one or more traces. + metrics: Metric names to evaluate (from list_metrics). Defaults to + ["tool_trajectory_avg_score"] if not specified. + trace_format: Format of the trace files. Either "jaeger-json" + (default) or "otlp-json". Use "otlp-json" for .jsonl files + exported by OpenTelemetry. + eval_set_file: Absolute path to a golden eval set JSON file (ADK + EvalSet format). Required by comparison metrics such as + "tool_trajectory_avg_score" and "response_match_score". + judge_model: LLM model name for judge-based metrics, for example + "gemini-2.5-flash" or "gemini-2.0-flash". Required by metrics + like "hallucinations_v1" and "final_response_match_v2". + threshold: Score threshold for PASS/FAIL classification, between + 0.0 and 1.0. Metric scores below this value are marked FAILED. + eval_config_file: Absolute path to an eval config YAML file that + defines custom evaluators. When provided, its settings are + merged with the other arguments (explicit arguments take + precedence over the config file). + + Returns: + An EvaluateTracesResponse with: + - passed: true if no metric across any trace has status "FAILED" + - traces: list of per-trace results, each containing: + - trace_id: identifier of the evaluated trace + - num_invocations: number of agent invocations in the trace + - metrics: list of metric results, each with: + - metric: the metric name + - score: numeric score (0.0 to 1.0), or null if not scored + - status: "PASSED", "FAILED", or "NOT_EVALUATED" + - error: error message if the metric failed to run + - warnings: conversion warnings, if any + - errors: top-level errors (e.g. trace files that failed to load) """ if metrics is None: metrics = ["tool_trajectory_avg_score"] @@ -128,38 +292,67 @@ async def evaluate_traces( threshold=threshold, ) result = await run_evaluation(config) - return _summarize_run_result(result) + return summarize_run_result(result) @mcp.tool() - async def list_sessions(limit: int = 20) -> list[dict[str, Any]]: - """List streaming trace sessions, most recent first. + async def list_sessions(limit: int = 20) -> list[SessionSummaryResponse]: + """List recent streaming trace sessions, ordered most recent first. + + Use this to discover session IDs for summarize_session or + evaluate_sessions. 
Sessions are created when agents stream traces to the + agentevals server via the SDK or an OpenTelemetry exporter. - Requires agentevals serve to be running. + Requires the agentevals server to be running (start with: + uv run agentevals serve --dev). Args: - limit: Maximum number of sessions to return (default: 20). + limit: Maximum number of sessions to return. Defaults to 20. + + Returns: + A list of SessionSummaryResponse objects, each containing: + - session_id: unique identifier to pass to other session tools + - is_complete: whether the session has finished receiving spans + - span_count: number of OpenTelemetry spans recorded + - started_at: ISO 8601 timestamp of when the session began """ sessions = await _get("/api/streaming/sessions") sessions.sort(key=lambda s: s.get("startedAt", ""), reverse=True) return [ - { - "sessionId": s["sessionId"], - "isComplete": s["isComplete"], - "spanCount": s["spanCount"], - "startedAt": s["startedAt"], - } + SessionSummaryResponse( + session_id=s["sessionId"], + is_complete=s["isComplete"], + span_count=s["spanCount"], + started_at=s["startedAt"], + ) for s in sessions[:limit] ] @mcp.tool() - async def summarize_session(session_id: str) -> dict[str, Any]: - """Get a structured summary of a session's invocations, tool calls, and messages. + async def summarize_session(session_id: str) -> SummarizeSessionResponse: + """Get a structured summary of a session showing its invocations, tool + calls, and messages in human-readable form. + + Use this to understand what an agent did during a session before running + evaluation. Parses the raw OpenTelemetry trace and extracts the + conversation flow: what the user said, what tools the agent called, and + how it responded. - Parses the raw trace and returns human-readable invocation data: user messages, - agent responses, and tool calls made. For the full span data, use get_session_trace. + Requires the agentevals server to be running. Args: - session_id: Session ID from list_sessions. + session_id: Session ID obtained from list_sessions. 
+ + Returns: + A SummarizeSessionResponse with: + - session_id: the requested session ID + - num_spans: total OpenTelemetry spans in the session + - num_invocations: number of agent invocations extracted + - invocations: chronological list of invocations, each containing: + - user: the user's input text + - response: the agent's final response text + - tool_calls: tools the agent called, each with: + - tool: the tool/function name + - args: arguments passed to the tool """ from agentevals.converter import convert_traces from agentevals.loader.otlp import OtlpJsonLoader @@ -172,7 +365,11 @@ async def summarize_session(session_id: str) -> dict[str, Any]: traces = OtlpJsonLoader().load(tmp_path) if not traces: - return {"session_id": session_id, "num_spans": raw["numSpans"], "invocations": []} + return SummarizeSessionResponse( + session_id=session_id, + num_spans=raw["numSpans"], + invocations=[], + ) invocations = [] for conv in convert_traces(traces): @@ -180,26 +377,25 @@ async def summarize_session(session_id: str) -> dict[str, Any]: tool_calls = [] if inv.intermediate_data: tool_calls = [ - {"tool": tu.name, "args": getattr(tu, "args", {})} for tu in inv.intermediate_data.tool_uses + ToolCallResponse(tool=tu.name, args=getattr(tu, "args", {})) + for tu in inv.intermediate_data.tool_uses ] invocations.append( - { - "user": next((p.text for p in inv.user_content.parts if p.text), "") - if inv.user_content - else "", - "response": next((p.text for p in inv.final_response.parts if p.text), "") + InvocationSummaryResponse( + user=next((p.text for p in inv.user_content.parts if p.text), "") if inv.user_content else "", + response=next((p.text for p in inv.final_response.parts if p.text), "") if inv.final_response else "", - "tool_calls": tool_calls, - } + tool_calls=tool_calls, + ) ) - return { - "session_id": session_id, - "num_spans": raw["numSpans"], - "num_invocations": len(invocations), - "invocations": invocations, - } + return SummarizeSessionResponse( + session_id=session_id, + num_spans=raw["numSpans"], + num_invocations=len(invocations), + invocations=invocations, + ) @mcp.tool() async def evaluate_sessions( @@ -207,24 +403,50 @@ async def evaluate_sessions( metrics: list[str] | None = None, judge_model: str = "gemini-2.5-flash", eval_set_id: str | None = None, - ) -> dict[str, Any]: + ) -> EvaluateSessionsResponse: """Evaluate all completed sessions against a golden reference session. - The server builds the eval set from the golden session automatically — no file - creation or pre-existing eval set needed. Call list_sessions first to find session IDs. + This is the primary tool for regression testing streamed agent sessions. + The server automatically builds an eval set from the golden session's + trace, then evaluates every other completed session against it. No file + creation or pre-existing eval set is needed. + + Typical workflow: + 1. Call list_sessions to find session IDs. + 2. Call summarize_session on a candidate to verify it represents the + expected "golden" behavior. + 3. Call evaluate_sessions with that session as the golden reference. + 4. Inspect per-session results to find regressions. - Requires agentevals serve to be running. + Requires the agentevals server to be running. Args: - golden_session_id: Session ID of the reference/golden run. - metrics: Metric names to evaluate. Use list_metrics to see available options. - judge_model: LLM model for judge-based metrics. - eval_set_id: A label for the eval set built from the golden session. 
You can use - any string or omit it — a default will be generated automatically. + golden_session_id: Session ID of the reference run. All other + completed sessions will be compared against this one. + metrics: Metric names to evaluate (from list_metrics). Defaults to + ["tool_trajectory_avg_score"]. Only metrics that support eval + set comparison are meaningful here. + judge_model: LLM model for judge-based metrics. Defaults to + "gemini-2.5-flash". + eval_set_id: A label for the eval set built from the golden session. + Any string is accepted. If omitted, a default is generated from + the golden session ID. + + Returns: + An EvaluateSessionsResponse with: + - golden_session_id: the reference session used + - eval_set_id: the label assigned to the generated eval set + - results: list of per-session evaluation results, each containing: + - session_id: the evaluated session + - trace_id: trace identifier, if available + - num_invocations: number of invocations evaluated, if available + - metric_results: list of metric score dicts (with metricName, + score, evalStatus, perInvocationScores, error, details fields) + - error: error message if evaluation of this session failed """ if metrics is None: metrics = ["tool_trajectory_avg_score"] - return await _post( + data = await _post( "/api/streaming/evaluate-sessions", { "golden_session_id": golden_session_id, @@ -233,5 +455,19 @@ async def evaluate_sessions( "judge_model": judge_model, }, ) + return EvaluateSessionsResponse( + golden_session_id=data["goldenSessionId"], + eval_set_id=data["evalSetId"], + results=[ + SessionEvalResultResponse( + session_id=r["sessionId"], + trace_id=r.get("traceId"), + num_invocations=r.get("numInvocations"), + metric_results=r.get("metricResults"), + error=r.get("error"), + ) + for r in data["results"] + ], + ) return mcp diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py new file mode 100644 index 0000000..614be24 --- /dev/null +++ b/tests/test_mcp_server.py @@ -0,0 +1,343 @@ +"""Tests for MCP server response models and transformation logic.""" + +from __future__ import annotations + +import pytest + +from agentevals.mcp_server import ( + EvaluateSessionsResponse, + EvaluateTracesResponse, + InvocationSummaryResponse, + MetricInfoResponse, + MetricScoreResponse, + SessionEvalResultResponse, + SessionSummaryResponse, + SummarizeSessionResponse, + ToolCallResponse, + TraceEvalResponse, + summarize_run_result, +) +from agentevals.runner import MetricResult, RunResult, TraceResult + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_metric_result( + name: str = "tool_trajectory_avg_score", + score: float | None = 0.85, + status: str = "PASSED", + error: str | None = None, +) -> MetricResult: + return MetricResult(metric_name=name, score=score, eval_status=status, error=error) + + +def _make_trace_result( + trace_id: str = "trace-1", + num_invocations: int = 2, + metric_results: list[MetricResult] | None = None, + warnings: list[str] | None = None, +) -> TraceResult: + return TraceResult( + trace_id=trace_id, + num_invocations=num_invocations, + metric_results=[_make_metric_result()] if metric_results is None else metric_results, + conversion_warnings=[] if warnings is None else warnings, + ) + + +def _make_run_result( + trace_results: list[TraceResult] | None = None, + errors: list[str] | None = None, +) -> RunResult: + return RunResult( + 
trace_results=[_make_trace_result()] if trace_results is None else trace_results, + errors=[] if errors is None else errors, + ) + + +# --------------------------------------------------------------------------- +# summarize_run_result +# --------------------------------------------------------------------------- + + +class TestSummarizeRunResult: + def test_single_passing_trace(self): + result = summarize_run_result(_make_run_result()) + + assert isinstance(result, EvaluateTracesResponse) + assert result.passed is True + assert len(result.traces) == 1 + assert result.traces[0].trace_id == "trace-1" + assert result.traces[0].num_invocations == 2 + assert result.traces[0].metrics[0].metric == "tool_trajectory_avg_score" + assert result.traces[0].metrics[0].score == 0.85 + assert result.traces[0].metrics[0].status == "PASSED" + assert result.traces[0].metrics[0].error is None + assert result.traces[0].warnings is None + assert result.errors is None + + def test_failed_metric_sets_passed_false(self): + run = _make_run_result( + trace_results=[_make_trace_result(metric_results=[_make_metric_result(status="FAILED", score=0.3)])] + ) + result = summarize_run_result(run) + + assert result.passed is False + + def test_mixed_pass_fail_across_traces(self): + run = _make_run_result( + trace_results=[ + _make_trace_result( + trace_id="t1", + metric_results=[_make_metric_result(status="PASSED")], + ), + _make_trace_result( + trace_id="t2", + metric_results=[_make_metric_result(status="FAILED", score=0.2)], + ), + ] + ) + result = summarize_run_result(run) + + assert result.passed is False + assert len(result.traces) == 2 + + def test_not_evaluated_does_not_cause_failure(self): + run = _make_run_result( + trace_results=[_make_trace_result(metric_results=[_make_metric_result(status="NOT_EVALUATED", score=None)])] + ) + result = summarize_run_result(run) + + assert result.passed is True + + def test_multiple_metrics_per_trace(self): + run = _make_run_result( + trace_results=[ + _make_trace_result( + metric_results=[ + _make_metric_result(name="tool_trajectory_avg_score", score=0.9, status="PASSED"), + _make_metric_result(name="response_match_score", score=0.7, status="PASSED"), + ] + ) + ] + ) + result = summarize_run_result(run) + + assert result.passed is True + assert len(result.traces[0].metrics) == 2 + assert result.traces[0].metrics[0].metric == "tool_trajectory_avg_score" + assert result.traces[0].metrics[1].metric == "response_match_score" + + def test_conversion_warnings_included(self): + run = _make_run_result(trace_results=[_make_trace_result(warnings=["Missing root span", "Unknown scope"])]) + result = summarize_run_result(run) + + assert result.traces[0].warnings == ["Missing root span", "Unknown scope"] + + def test_empty_warnings_becomes_none(self): + run = _make_run_result(trace_results=[_make_trace_result(warnings=[])]) + result = summarize_run_result(run) + + assert result.traces[0].warnings is None + + def test_errors_included(self): + run = _make_run_result(errors=["Failed to load trace file 'bad.json'"]) + result = summarize_run_result(run) + + assert result.errors == ["Failed to load trace file 'bad.json'"] + + def test_empty_errors_becomes_none(self): + run = _make_run_result(errors=[]) + result = summarize_run_result(run) + + assert result.errors is None + + def test_metric_error_preserved(self): + run = _make_run_result( + trace_results=[ + _make_trace_result( + metric_results=[ + _make_metric_result( + status="NOT_EVALUATED", + score=None, + error="Metric requires eval set", + ) 
+ ] + ) + ] + ) + result = summarize_run_result(run) + + assert result.traces[0].metrics[0].error == "Metric requires eval set" + + def test_empty_metric_error_becomes_none(self): + run = _make_run_result(trace_results=[_make_trace_result(metric_results=[_make_metric_result(error="")])]) + result = summarize_run_result(run) + + assert result.traces[0].metrics[0].error is None + + def test_no_traces(self): + run = _make_run_result(trace_results=[]) + result = summarize_run_result(run) + + assert result.passed is True + assert result.traces == [] + + def test_no_metrics_on_trace(self): + run = _make_run_result(trace_results=[_make_trace_result(metric_results=[])]) + result = summarize_run_result(run) + + assert result.passed is True + assert result.traces[0].metrics == [] + + +# --------------------------------------------------------------------------- +# Response model serialization +# --------------------------------------------------------------------------- + + +class TestMetricInfoResponse: + def test_from_api_shaped_data(self): + api_data = { + "name": "tool_trajectory_avg_score", + "category": "trajectory", + "requiresEvalSet": True, + "requiresLLM": False, + "requiresGCP": False, + "requiresRubrics": False, + "description": "Compares tool call sequences", + "working": True, + } + model = MetricInfoResponse( + name=api_data["name"], + category=api_data["category"], + requires_eval_set=api_data["requiresEvalSet"], + requires_llm=api_data["requiresLLM"], + requires_gcp=api_data["requiresGCP"], + requires_rubrics=api_data["requiresRubrics"], + description=api_data["description"], + working=api_data["working"], + ) + + assert model.name == "tool_trajectory_avg_score" + assert model.requires_eval_set is True + assert model.requires_llm is False + dumped = model.model_dump() + assert "requires_eval_set" in dumped + assert "requiresEvalSet" not in dumped + + +class TestSessionSummaryResponse: + def test_from_api_shaped_data(self): + api_data = { + "sessionId": "sess-abc", + "isComplete": True, + "spanCount": 42, + "startedAt": "2025-01-15T10:30:00Z", + } + model = SessionSummaryResponse( + session_id=api_data["sessionId"], + is_complete=api_data["isComplete"], + span_count=api_data["spanCount"], + started_at=api_data["startedAt"], + ) + + assert model.session_id == "sess-abc" + assert model.is_complete is True + assert model.span_count == 42 + dumped = model.model_dump() + assert "session_id" in dumped + assert "sessionId" not in dumped + + +class TestEvaluateSessionsResponse: + def test_from_api_shaped_data(self): + api_data = { + "goldenSessionId": "golden-1", + "evalSetId": "eval-golden-1", + "results": [ + { + "sessionId": "sess-2", + "traceId": "t2", + "numInvocations": 3, + "metricResults": [ + {"metricName": "tool_trajectory_avg_score", "score": 0.9, "evalStatus": "PASSED"} + ], + "error": None, + }, + { + "sessionId": "sess-3", + "traceId": None, + "numInvocations": None, + "metricResults": None, + "error": "Session has no spans", + }, + ], + } + model = EvaluateSessionsResponse( + golden_session_id=api_data["goldenSessionId"], + eval_set_id=api_data["evalSetId"], + results=[ + SessionEvalResultResponse( + session_id=r["sessionId"], + trace_id=r.get("traceId"), + num_invocations=r.get("numInvocations"), + metric_results=r.get("metricResults"), + error=r.get("error"), + ) + for r in api_data["results"] + ], + ) + + assert model.golden_session_id == "golden-1" + assert len(model.results) == 2 + assert model.results[0].session_id == "sess-2" + assert 
model.results[0].metric_results[0]["score"] == 0.9 + assert model.results[1].error == "Session has no spans" + assert model.results[1].metric_results is None + + +class TestSummarizeSessionResponse: + def test_with_invocations(self): + model = SummarizeSessionResponse( + session_id="sess-1", + num_spans=15, + num_invocations=2, + invocations=[ + InvocationSummaryResponse( + user="deploy nginx", + response="I'll deploy nginx using helm.", + tool_calls=[ + ToolCallResponse(tool="helm_install", args={"chart": "nginx"}), + ToolCallResponse(tool="kubectl_get", args={"resource": "pods"}), + ], + ), + InvocationSummaryResponse( + user="check status", + response="All pods are running.", + tool_calls=[], + ), + ], + ) + + assert model.num_invocations == 2 + assert model.invocations[0].tool_calls[0].tool == "helm_install" + assert model.invocations[0].tool_calls[0].args == {"chart": "nginx"} + assert model.invocations[1].tool_calls == [] + + def test_empty_session(self): + model = SummarizeSessionResponse( + session_id="sess-empty", + num_spans=0, + invocations=[], + ) + + assert model.num_invocations == 0 + assert model.invocations == [] + + def test_tool_call_default_args(self): + tc = ToolCallResponse(tool="my_tool") + assert tc.args == {}
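

# ---------------------------------------------------------------------------
# Serialization round-trip (sketch)
# ---------------------------------------------------------------------------
# A minimal sketch of how an MCP client is expected to see these responses:
# FastMCP should serialize the returned pydantic models to JSON, so the dumped
# dict ought to preserve the optional-field semantics. Assumes pydantic v2
# (model_dump / model_validate), which the model_dump() calls above already
# rely on; the test class name below is illustrative, not part of the patch's
# public API.


class TestEvaluateTracesResponseSerialization:
    def test_model_dump_round_trip(self):
        result = summarize_run_result(_make_run_result())

        # Dumping to a plain dict mirrors the JSON payload an MCP client receives.
        dumped = result.model_dump()
        assert dumped["passed"] is True
        assert dumped["traces"][0]["metrics"][0]["metric"] == "tool_trajectory_avg_score"
        assert dumped["traces"][0]["warnings"] is None
        assert dumped["errors"] is None

        # Re-validating the dumped dict should reproduce an equal model,
        # i.e. the wire shape is stable across a round trip.
        restored = EvaluateTracesResponse.model_validate(dumped)
        assert restored == result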