From b510cfa931ac3373f76a70e56e3850a4e526a0c5 Mon Sep 17 00:00:00 2001 From: krisztianfekete Date: Tue, 31 Mar 2026 11:56:32 +0200 Subject: [PATCH] replace _trace_manager variables with fastAPI dependency injection --- src/agentevals/api/app.py | 32 +++---- src/agentevals/api/debug_routes.py | 44 ++++----- src/agentevals/api/dependencies.py | 23 +++++ src/agentevals/api/otlp_app.py | 8 +- src/agentevals/api/otlp_routes.py | 74 +++++++--------- src/agentevals/api/streaming_routes.py | 118 ++++++++++++++----------- tests/integration/conftest.py | 18 ++-- tests/test_api.py | 22 ++--- tests/test_otlp_receiver.py | 74 ++++++---------- 9 files changed, 201 insertions(+), 212 deletions(-) create mode 100644 src/agentevals/api/dependencies.py diff --git a/src/agentevals/api/app.py b/src/agentevals/api/app.py index 3969be8..695ee56 100644 --- a/src/agentevals/api/app.py +++ b/src/agentevals/api/app.py @@ -15,7 +15,6 @@ from ..utils.log_buffer import log_buffer from .debug_routes import debug_router -from .debug_routes import set_trace_manager as set_debug_trace_manager from .routes import router try: @@ -42,11 +41,12 @@ async def lifespan(app: FastAPI): if log_buffer not in ae_logger.handlers: log_buffer.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s")) ae_logger.addHandler(log_buffer) - if _trace_manager: - _trace_manager.start_cleanup_task() + mgr = getattr(app.state, "trace_manager", None) + if mgr: + mgr.start_cleanup_task() yield - if _trace_manager: - await _trace_manager.shutdown() + if mgr: + await mgr.shutdown() ae_logger.removeHandler(log_buffer) @@ -70,27 +70,27 @@ async def lifespan(app: FastAPI): app.include_router(debug_router, prefix="/api/debug") _live_mode = os.getenv("AGENTEVALS_LIVE") == "1" -_trace_manager = None if _live_mode: + from fastapi import Request as _Request from fastapi import WebSocket from ..streaming.ws_server import StreamingTraceManager - from .streaming_routes import set_trace_manager, streaming_router + from .streaming_routes import streaming_router app.include_router(streaming_router, prefix="/api/streaming") - _trace_manager = StreamingTraceManager() - set_trace_manager(_trace_manager) - set_debug_trace_manager(_trace_manager) + app.state.trace_manager = StreamingTraceManager() @app.websocket("/ws/traces") async def websocket_endpoint(websocket: WebSocket): - await _trace_manager.handle_connection(websocket) + await websocket.app.state.trace_manager.handle_connection(websocket) @app.get("/stream/ui-updates") - async def ui_updates_stream(): + async def ui_updates_stream(request: _Request): + mgr = request.app.state.trace_manager + async def event_generator(): - queue = _trace_manager.register_sse_client() + queue = mgr.register_sse_client() try: while True: event = await queue.get() @@ -100,7 +100,7 @@ async def event_generator(): except asyncio.CancelledError: pass finally: - _trace_manager.unregister_sse_client(queue) + mgr.unregister_sse_client(queue) return StreamingResponse( event_generator(), @@ -112,10 +112,6 @@ async def event_generator(): ) -def get_trace_manager(): - return _trace_manager - - _static_dir = Path(__file__).parent.parent / "_static" _has_ui = _static_dir.is_dir() and (_static_dir / "index.html").exists() diff --git a/src/agentevals/api/debug_routes.py b/src/agentevals/api/debug_routes.py index f92e47d..4a3d8e6 100644 --- a/src/agentevals/api/debug_routes.py +++ b/src/agentevals/api/debug_routes.py @@ -13,7 +13,7 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING -from fastapi import APIRouter, HTTPException, UploadFile +from fastapi import APIRouter, Depends, HTTPException, UploadFile from fastapi import File as FastAPIFile from fastapi.responses import StreamingResponse from pydantic import BaseModel @@ -21,6 +21,7 @@ from agentevals import __version__ from ..utils.log_buffer import log_buffer +from .dependencies import get_trace_manager, require_trace_manager from .models import DebugLoadData, SessionInfo, StandardResponse, WSSessionCompleteEvent, WSSessionStartedEvent if TYPE_CHECKING: @@ -30,13 +31,6 @@ debug_router = APIRouter() -_trace_manager: StreamingTraceManager | None = None - - -def set_trace_manager(manager: StreamingTraceManager) -> None: - global _trace_manager - _trace_manager = manager - class FrontendDiagnostics(BaseModel): user_description: str = "" @@ -83,12 +77,12 @@ def _collect_environment() -> dict: } -def _collect_sessions() -> list[dict]: - if not _trace_manager: +def _collect_sessions(manager: StreamingTraceManager | None) -> list[dict]: + if not manager: return [] sessions_data = [] - for session in _trace_manager.sessions.values(): + for session in manager.sessions.values(): sessions_data.append( { "session_id": session.session_id, @@ -128,7 +122,10 @@ def _collect_temp_files(session_ids: set[str] | None = None) -> dict[str, str]: @debug_router.post("/bundle") -async def create_debug_bundle(diagnostics: FrontendDiagnostics): +async def create_debug_bundle( + diagnostics: FrontendDiagnostics, + manager: StreamingTraceManager | None = Depends(get_trace_manager), +): timestamp = datetime.now(tz=UTC).strftime("%Y%m%d-%H%M%S") prefix = f"bug-report-{timestamp}" @@ -142,7 +139,7 @@ async def create_debug_bundle(diagnostics: FrontendDiagnostics): } zf.writestr(f"{prefix}/metadata.json", json.dumps(metadata, indent=2)) - sessions = _collect_sessions() + sessions = _collect_sessions(manager) for s in sessions: sid = s["session_id"] zf.writestr( @@ -188,13 +185,10 @@ async def create_debug_bundle(diagnostics: FrontendDiagnostics): @debug_router.post("/load", response_model=StandardResponse[DebugLoadData]) -async def load_debug_bundle(file: UploadFile = FastAPIFile(...)): - if not _trace_manager: - raise HTTPException( - status_code=400, - detail="Live mode is not enabled. Start with: agentevals serve --dev", - ) - +async def load_debug_bundle( + file: UploadFile = FastAPIFile(...), + manager: StreamingTraceManager = Depends(require_trace_manager), +): content = await file.read() try: zf = zipfile.ZipFile(io.BytesIO(content)) @@ -236,9 +230,9 @@ async def load_debug_bundle(file: UploadFile = FastAPIFile(...)): metadata=meta.get("metadata", {}), ) - _trace_manager.sessions[session.session_id] = session + manager.sessions[session.session_id] = session - await _trace_manager.broadcast_to_ui( + await manager.broadcast_to_ui( WSSessionStartedEvent( session=SessionInfo( session_id=session.session_id, @@ -252,10 +246,10 @@ async def load_debug_bundle(file: UploadFile = FastAPIFile(...)): ).model_dump(by_alias=True) ) - invocations_data = await _trace_manager._extract_invocations(session) - await _trace_manager._save_spans_to_temp_file(session) + invocations_data = await manager._extract_invocations(session) + await manager._save_spans_to_temp_file(session) - await _trace_manager.broadcast_to_ui( + await manager.broadcast_to_ui( WSSessionCompleteEvent( session_id=session.session_id, invocations=invocations_data, diff --git a/src/agentevals/api/dependencies.py b/src/agentevals/api/dependencies.py new file mode 100644 index 0000000..452b676 --- /dev/null +++ b/src/agentevals/api/dependencies.py @@ -0,0 +1,23 @@ +"""FastAPI dependency functions for shared services.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import HTTPException, Request + +if TYPE_CHECKING: + from ..streaming.ws_server import StreamingTraceManager + + +def get_trace_manager(request: Request) -> StreamingTraceManager | None: + """Return the StreamingTraceManager or None if live mode is off.""" + return getattr(request.app.state, "trace_manager", None) + + +def require_trace_manager(request: Request) -> StreamingTraceManager: + """Return the StreamingTraceManager, raising 503 if live mode is off.""" + mgr = getattr(request.app.state, "trace_manager", None) + if mgr is None: + raise HTTPException(status_code=503, detail="Live mode not enabled") + return mgr diff --git a/src/agentevals/api/otlp_app.py b/src/agentevals/api/otlp_app.py index 006f821..d04fd1e 100644 --- a/src/agentevals/api/otlp_app.py +++ b/src/agentevals/api/otlp_app.py @@ -8,16 +8,16 @@ from fastapi import FastAPI -from .otlp_routes import otlp_router, set_trace_manager +from .otlp_routes import otlp_router @asynccontextmanager async def lifespan(app: FastAPI): - from .app import get_trace_manager + from .app import app as main_app - mgr = get_trace_manager() + mgr = getattr(main_app.state, "trace_manager", None) if mgr: - set_trace_manager(mgr) + app.state.trace_manager = mgr yield diff --git a/src/agentevals/api/otlp_routes.py b/src/agentevals/api/otlp_routes.py index 9551fd5..5298aaa 100644 --- a/src/agentevals/api/otlp_routes.py +++ b/src/agentevals/api/otlp_routes.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING -from fastapi import APIRouter, Request, Response +from fastapi import APIRouter, Depends, Request, Response from google.protobuf.json_format import MessageToDict from opentelemetry.proto.collector.logs.v1.logs_service_pb2 import ( ExportLogsServiceRequest as LogsServiceRequestPB, @@ -31,6 +31,7 @@ OTEL_SCOPE, OTEL_SCOPE_VERSION, ) +from .dependencies import require_trace_manager from .models import WSSpanReceivedEvent if TYPE_CHECKING: @@ -39,23 +40,17 @@ logger = logging.getLogger(__name__) otlp_router = APIRouter() -_trace_manager: StreamingTraceManager | None = None AGENTEVALS_EVAL_SET_ID = "agentevals.eval_set_id" AGENTEVALS_SESSION_NAME = "agentevals.session_name" -def set_trace_manager(manager: StreamingTraceManager) -> None: - global _trace_manager - _trace_manager = manager - - @otlp_router.post("/v1/traces") -async def receive_traces(request: Request) -> Response: +async def receive_traces( + request: Request, + manager: StreamingTraceManager = Depends(require_trace_manager), +) -> Response: """OTLP HTTP trace receiver (ExportTraceServiceRequest).""" - if not _trace_manager: - return Response(status_code=503, content="Live mode not enabled") - content_type = request.headers.get("content-type", "") if "application/x-protobuf" in content_type: @@ -64,7 +59,7 @@ async def receive_traces(request: Request) -> Response: else: body = await request.json() - await _process_traces(body) + await _process_traces(body, manager) return Response( status_code=200, content='{"partialSuccess":{}}', @@ -73,11 +68,11 @@ async def receive_traces(request: Request) -> Response: @otlp_router.post("/v1/logs") -async def receive_logs(request: Request) -> Response: +async def receive_logs( + request: Request, + manager: StreamingTraceManager = Depends(require_trace_manager), +) -> Response: """OTLP HTTP log receiver (ExportLogsServiceRequest).""" - if not _trace_manager: - return Response(status_code=503, content="Live mode not enabled") - content_type = request.headers.get("content-type", "") if "application/x-protobuf" in content_type: @@ -86,7 +81,7 @@ async def receive_logs(request: Request) -> Response: else: body = await request.json() - await _process_logs(body) + await _process_logs(body, manager) return Response( status_code=200, content='{"partialSuccess":{}}', @@ -94,7 +89,7 @@ async def receive_logs(request: Request) -> Response: ) -async def _process_traces(body: dict) -> None: +async def _process_traces(body: dict, manager: StreamingTraceManager) -> None: """Parse ExportTraceServiceRequest and feed spans to the pipeline.""" for resource_span in body.get("resourceSpans", []): resource_attrs = resource_span.get("resource", {}).get("attributes", []) @@ -112,7 +107,7 @@ async def _process_traces(body: dict) -> None: if not trace_id: continue - session = await _trace_manager.get_or_create_otlp_session(trace_id, metadata) + session = await manager.get_or_create_otlp_session(trace_id, metadata) if not session.can_accept_span(): logger.warning("Session %s at span limit", session.session_id) @@ -120,28 +115,28 @@ async def _process_traces(body: dict) -> None: session.spans.append(span) - extractor = _trace_manager.incremental_extractors.get(session.session_id) + extractor = manager.incremental_extractors.get(session.session_id) if extractor: updates = extractor.process_span(span) for update in updates: update["sessionId"] = session.session_id - await _trace_manager.broadcast_to_ui(update) + await manager.broadcast_to_ui(update) - await _trace_manager.broadcast_to_ui( + await manager.broadcast_to_ui( WSSpanReceivedEvent( session_id=session.session_id, span=span, ).model_dump(by_alias=True) ) - _trace_manager.reset_idle_timer(session.session_id) + manager.reset_idle_timer(session.session_id) if not span.get("parentSpanId"): session.has_root_span = True - _trace_manager.schedule_session_completion(session.session_id) + manager.schedule_session_completion(session.session_id) -async def _process_logs(body: dict) -> None: +async def _process_logs(body: dict, manager: StreamingTraceManager) -> None: """Parse ExportLogsServiceRequest and feed logs to sessions. Logs and spans arrive via separate OTLP exporters (BatchLogRecordProcessor @@ -170,23 +165,22 @@ async def _process_logs(body: dict) -> None: if not trace_id: continue - session = _trace_manager.find_session_by_trace_id(trace_id) + session = manager.find_session_by_trace_id(trace_id) - if not session and session_name and _trace_manager: - active_id = _trace_manager._active_session_for_name.get(session_name) - candidate = _trace_manager.sessions.get(active_id) if active_id else None + if not session and session_name: + active_id = manager._active_session_for_name.get(session_name) + candidate = manager.sessions.get(active_id) if active_id else None if candidate and not candidate.is_complete: candidate.trace_ids.add(trace_id) session = candidate if not session: - if _trace_manager: - _trace_manager.buffer_orphan_log(trace_id, session_name, log_event) - logger.debug( - "Buffered orphan log trace_id=%s session_name=%s", - trace_id[:12], - session_name, - ) + manager.buffer_orphan_log(trace_id, session_name, log_event) + logger.debug( + "Buffered orphan log trace_id=%s session_name=%s", + trace_id[:12], + session_name, + ) continue if not session.can_accept_log(): @@ -197,17 +191,17 @@ async def _process_logs(body: dict) -> None: if session.is_complete: sessions_needing_reextraction.add(session.session_id) else: - _trace_manager.reset_idle_timer(session.session_id) + manager.reset_idle_timer(session.session_id) - extractor = _trace_manager.incremental_extractors.get(session.session_id) + extractor = manager.incremental_extractors.get(session.session_id) if extractor: updates = extractor.process_log(log_event) for update in updates: update["sessionId"] = session.session_id - await _trace_manager.broadcast_to_ui(update) + await manager.broadcast_to_ui(update) for session_id in sessions_needing_reextraction: - _trace_manager.schedule_log_reextraction(session_id) + manager.schedule_log_reextraction(session_id) _GENAI_EVENT_KEYS = {OTEL_GENAI_INPUT_MESSAGES, OTEL_GENAI_OUTPUT_MESSAGES} diff --git a/src/agentevals/api/streaming_routes.py b/src/agentevals/api/streaming_routes.py index 30c12d7..6c43993 100644 --- a/src/agentevals/api/streaming_routes.py +++ b/src/agentevals/api/streaming_routes.py @@ -5,8 +5,9 @@ import asyncio import json import logging +from typing import TYPE_CHECKING -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import FileResponse from pydantic import BaseModel @@ -16,6 +17,7 @@ from ..runner import run_evaluation from ..trace_attrs import OTEL_GENAI_INPUT_MESSAGES, OTEL_GENAI_REQUEST_MODEL from ..utils.log_enrichment import enrich_spans_with_logs +from .dependencies import require_trace_manager from .models import ( CreateEvalSetData, EvaluateSessionsData, @@ -26,18 +28,13 @@ StandardResponse, ) +if TYPE_CHECKING: + from ..streaming.ws_server import StreamingTraceManager + logger = logging.getLogger(__name__) streaming_router = APIRouter() -trace_manager = None - - -def set_trace_manager(manager): - """Set the trace manager instance.""" - global trace_manager - trace_manager = manager - class CreateEvalSetRequest(BaseModel): session_id: str @@ -60,35 +57,16 @@ class GetTraceRequest(BaseModel): session_id: str -@streaming_router.get("/sessions", response_model=StandardResponse[list[SessionInfo]]) -async def list_sessions(): - sessions_data = [] - - for session_id, session in trace_manager.sessions.items(): - info = SessionInfo( - session_id=session_id, - trace_id=session.trace_id, - eval_set_id=session.eval_set_id, - span_count=len(session.spans), - is_complete=session.is_complete, - started_at=session.started_at.isoformat(), - metadata=session.metadata, - invocations=session.invocations if session.is_complete and session.invocations else None, - ) - sessions_data.append(info) - - return StandardResponse(data=sessions_data) - - -@streaming_router.post("/create-eval-set", response_model=StandardResponse[CreateEvalSetData]) -async def create_eval_set_from_session(request: CreateEvalSetRequest): - """Convert a session's trace into an EvalSet.""" - session = trace_manager.sessions.get(request.session_id) +async def _do_create_eval_set( + request: CreateEvalSetRequest, manager: StreamingTraceManager +) -> StandardResponse[CreateEvalSetData]: + """Shared logic for creating an EvalSet from a session's trace.""" + session = manager.sessions.get(request.session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") try: - trace_file = await trace_manager._save_spans_to_temp_file(session) + trace_file = await manager._save_spans_to_temp_file(session) logger.debug( "Session %s: %d spans, %d logs saved to %s", request.session_id, @@ -159,19 +137,52 @@ async def create_eval_set_from_session(request: CreateEvalSetRequest): raise HTTPException(status_code=500, detail=str(exc)) from exc +@streaming_router.get("/sessions", response_model=StandardResponse[list[SessionInfo]]) +async def list_sessions(manager: StreamingTraceManager = Depends(require_trace_manager)): + sessions_data = [] + + for session_id, session in manager.sessions.items(): + info = SessionInfo( + session_id=session_id, + trace_id=session.trace_id, + eval_set_id=session.eval_set_id, + span_count=len(session.spans), + is_complete=session.is_complete, + started_at=session.started_at.isoformat(), + metadata=session.metadata, + invocations=session.invocations if session.is_complete and session.invocations else None, + ) + sessions_data.append(info) + + return StandardResponse(data=sessions_data) + + +@streaming_router.post("/create-eval-set", response_model=StandardResponse[CreateEvalSetData]) +async def create_eval_set_from_session( + request: CreateEvalSetRequest, + manager: StreamingTraceManager = Depends(require_trace_manager), +): + """Convert a session's trace into an EvalSet.""" + return await _do_create_eval_set(request, manager) + + @streaming_router.post("/evaluate-sessions", response_model=StandardResponse[EvaluateSessionsData]) -async def evaluate_sessions(request: EvaluateSessionsRequest): +async def evaluate_sessions( + request: EvaluateSessionsRequest, + manager: StreamingTraceManager = Depends(require_trace_manager), +): """Evaluate all sessions against a golden session converted to EvalSet.""" - golden_session = trace_manager.sessions.get(request.golden_session_id) + golden_session = manager.sessions.get(request.golden_session_id) if not golden_session: raise HTTPException(status_code=404, detail="Golden session not found") try: - eval_set_response = await create_eval_set_from_session( + eval_set_response = await _do_create_eval_set( CreateEvalSetRequest( session_id=request.golden_session_id, eval_set_id=request.eval_set_id, - ) + ), + manager, ) import tempfile @@ -181,19 +192,17 @@ async def evaluate_sessions(request: EvaluateSessionsRequest): eval_set_file.close() sessions_to_evaluate = [ - (session_id, session) for session_id, session in trace_manager.sessions.items() if session.is_complete + (session_id, session) for session_id, session in manager.sessions.items() if session.is_complete ] - logger.info( - "Evaluating %d complete sessions (of %d total)", len(sessions_to_evaluate), len(trace_manager.sessions) - ) + logger.info("Evaluating %d complete sessions (of %d total)", len(sessions_to_evaluate), len(manager.sessions)) sem = asyncio.Semaphore(5) async def eval_one_session(session_id: str, session) -> SessionEvalResult: async with sem: try: - trace_file = await trace_manager._save_spans_to_temp_file(session) + trace_file = await manager._save_spans_to_temp_file(session) config = EvalRunConfig( trace_files=[str(trace_file)], @@ -249,18 +258,22 @@ async def eval_one_session(session_id: str, session) -> SessionEvalResult: @streaming_router.post("/prepare-evaluation", response_model=StandardResponse[PrepareEvaluationData]) -async def prepare_evaluation(request: PrepareEvaluationRequest): +async def prepare_evaluation( + request: PrepareEvaluationRequest, + manager: StreamingTraceManager = Depends(require_trace_manager), +): """Prepare evaluation by saving traces and eval set as downloadable files.""" - golden_session = trace_manager.sessions.get(request.golden_session_id) + golden_session = manager.sessions.get(request.golden_session_id) if not golden_session: raise HTTPException(status_code=404, detail="Golden session not found") try: - eval_set_response = await create_eval_set_from_session( + eval_set_response = await _do_create_eval_set( CreateEvalSetRequest( session_id=request.golden_session_id, eval_set_id=f"golden_{request.golden_session_id}", - ) + ), + manager, ) import os @@ -274,11 +287,11 @@ async def prepare_evaluation(request: PrepareEvaluationRequest): trace_files = [] for session_id in request.session_ids: - session = trace_manager.sessions.get(session_id) + session = manager.sessions.get(session_id) if not session or not session.is_complete: continue - trace_file = await trace_manager._save_spans_to_temp_file(session) + trace_file = await manager._save_spans_to_temp_file(session) trace_files.append( { "session_id": session_id, @@ -320,8 +333,11 @@ async def download_file(filename: str): @streaming_router.post("/get-trace", response_model=StandardResponse[GetTraceData]) -async def get_trace(request: GetTraceRequest): - session = trace_manager.sessions.get(request.session_id) +async def get_trace( + request: GetTraceRequest, + manager: StreamingTraceManager = Depends(require_trace_manager), +): + session = manager.sessions.get(request.session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 35b42eb..c63a317 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -39,13 +39,12 @@ async def trace_manager(): @pytest.fixture async def otlp_client(trace_manager): """httpx client → OTLP app via ASGI transport (no real server).""" - from agentevals.api.otlp_routes import otlp_router, set_trace_manager - - set_trace_manager(trace_manager) + from agentevals.api.otlp_routes import otlp_router from fastapi import FastAPI test_app = FastAPI() + test_app.state.trace_manager = trace_manager test_app.include_router(otlp_router) transport = httpx.ASGITransport(app=test_app) @@ -56,13 +55,12 @@ async def otlp_client(trace_manager): @pytest.fixture async def api_client(trace_manager): """httpx client → main app streaming routes via ASGI transport.""" - from agentevals.api.streaming_routes import set_trace_manager, streaming_router - - set_trace_manager(trace_manager) + from agentevals.api.streaming_routes import streaming_router from fastapi import FastAPI test_app = FastAPI() + test_app.state.trace_manager = trace_manager test_app.include_router(streaming_router, prefix="/api/streaming") transport = httpx.ASGITransport(app=test_app) @@ -109,12 +107,12 @@ def live_servers(): importlib.reload(app_module) - from agentevals.api.app import app, get_trace_manager + from agentevals.api.app import app from agentevals.api.otlp_app import otlp_app - from agentevals.api.otlp_routes import set_trace_manager - mgr = get_trace_manager() - set_trace_manager(mgr) + mgr = getattr(app.state, "trace_manager", None) + if mgr: + otlp_app.state.trace_manager = mgr main_config = uvicorn.Config(app, host="127.0.0.1", port=main_port, log_level="warning") otlp_config = uvicorn.Config(otlp_app, host="127.0.0.1", port=otlp_port, log_level="warning") diff --git a/tests/test_api.py b/tests/test_api.py index 64d1eba..452d181 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -21,7 +21,6 @@ from fastapi.testclient import TestClient from agentevals.api.debug_routes import debug_router -from agentevals.api.debug_routes import set_trace_manager as set_debug_trace_manager from agentevals.api.models import ( CamelModel, CreateEvalSetData, @@ -31,12 +30,7 @@ StandardResponse, ) from agentevals.api.routes import _camel_keys, router -from agentevals.api.streaming_routes import ( - set_trace_manager as set_streaming_trace_manager, -) -from agentevals.api.streaming_routes import ( - streaming_router, -) +from agentevals.api.streaming_routes import streaming_router from agentevals.runner import MetricResult, RunResult, TraceResult from agentevals.streaming.session import TraceSession @@ -58,8 +52,7 @@ def _make_app() -> FastAPI: def _make_live_app(mgr) -> FastAPI: app = _make_app() app.include_router(streaming_router, prefix="/api/streaming") - set_streaming_trace_manager(mgr) - set_debug_trace_manager(mgr) + app.state.trace_manager = mgr return app @@ -691,7 +684,7 @@ def test_evaluate_sessions_missing_golden(self): assert resp.status_code == 404 @patch("agentevals.api.streaming_routes.run_evaluation", new_callable=AsyncMock) - @patch("agentevals.api.streaming_routes.create_eval_set_from_session", new_callable=AsyncMock) + @patch("agentevals.api.streaming_routes._do_create_eval_set", new_callable=AsyncMock) def test_evaluate_sessions_success(self, mock_create_eval, mock_eval): self.mgr.sessions.clear() self.mgr.sessions["golden"] = _make_session("golden", "tg") @@ -722,7 +715,7 @@ def test_evaluate_sessions_success(self, mock_create_eval, mock_eval): _assert_all_keys_camel(body) @patch("agentevals.api.streaming_routes.run_evaluation", new_callable=AsyncMock) - @patch("agentevals.api.streaming_routes.create_eval_set_from_session", new_callable=AsyncMock) + @patch("agentevals.api.streaming_routes._do_create_eval_set", new_callable=AsyncMock) def test_evaluate_sessions_eval_failure(self, mock_create_eval, mock_eval): self.mgr.sessions.clear() self.mgr.sessions["golden"] = _make_session("golden", "tg") @@ -774,7 +767,7 @@ def test_prepare_missing_golden(self): ) assert resp.status_code == 404 - @patch("agentevals.api.streaming_routes.create_eval_set_from_session", new_callable=AsyncMock) + @patch("agentevals.api.streaming_routes._do_create_eval_set", new_callable=AsyncMock) def test_prepare_success(self, mock_create_eval): self.mgr.sessions.clear() self.mgr.sessions["golden"] = _make_session("golden", "tg") @@ -802,7 +795,7 @@ def test_prepare_success(self, mock_create_eval): assert body["data"]["numTraces"] == 1 _assert_all_keys_camel(body) - @patch("agentevals.api.streaming_routes.create_eval_set_from_session", new_callable=AsyncMock) + @patch("agentevals.api.streaming_routes._do_create_eval_set", new_callable=AsyncMock) def test_prepare_skips_incomplete(self, mock_create_eval): self.mgr.sessions.clear() self.mgr.sessions["golden"] = _make_session("golden", "tg") @@ -968,13 +961,12 @@ def test_bundle_zip_contents(self): class TestDebugLoad: def test_load_no_live_mode(self): - set_debug_trace_manager(None) client = TestClient(_make_app()) resp = client.post( "/api/debug/load", files={"file": ("report.zip", io.BytesIO(b"fake"), "application/zip")}, ) - assert resp.status_code == 400 + assert resp.status_code == 503 def test_load_invalid_zip(self): mgr = _make_trace_manager() diff --git a/tests/test_otlp_receiver.py b/tests/test_otlp_receiver.py index 62055a2..67ac00c 100644 --- a/tests/test_otlp_receiver.py +++ b/tests/test_otlp_receiver.py @@ -14,7 +14,6 @@ _parse_otlp_body, _process_logs, _process_traces, - set_trace_manager, ) from agentevals.streaming.session import TraceSession from agentevals.streaming.ws_server import StreamingTraceManager @@ -475,7 +474,6 @@ class TestLateLogReextraction: def test_late_logs_accepted_for_completed_session(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) meta = {"eval_set_id": None, "session_name": "s1", "resource_attrs": {}} session = await mgr.get_or_create_otlp_session("trace-abc", meta) session.is_complete = True @@ -507,7 +505,7 @@ async def go(): } ] } - await _process_logs(body) + await _process_logs(body, mgr) assert len(session.logs) == 1 mgr.schedule_log_reextraction.assert_called_once_with("s1") @@ -520,7 +518,6 @@ def test_late_logs_not_matched_to_completed_session_by_name(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) meta = {"eval_set_id": None, "session_name": "named-session", "resource_attrs": {}} session = await mgr.get_or_create_otlp_session("trace-abc", meta) session.is_complete = True @@ -550,7 +547,7 @@ async def go(): } ] } - await _process_logs(body) + await _process_logs(body, mgr) assert len(session.logs) == 0 assert "new-trace-id" not in session.trace_ids @@ -759,14 +756,13 @@ class TestProcessTraces: def test_single_span_creates_session(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) body = _make_export_request( spans=[_make_span(trace_id="t1", parent_span_id="p1")], resource_attrs=[ _make_otlp_attr("agentevals.session_name", "test-session"), ], ) - await _process_traces(body) + await _process_traces(body, mgr) assert "test-session" in mgr.sessions session = mgr.sessions["test-session"] assert len(session.spans) == 1 @@ -778,14 +774,13 @@ async def go(): def test_multiple_spans_same_trace(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) body = _make_export_request( spans=[ _make_span(trace_id="t1", span_id="s1", parent_span_id="p1"), _make_span(trace_id="t1", span_id="s2", parent_span_id="p1"), ], ) - await _process_traces(body) + await _process_traces(body, mgr) sessions = [s for s in mgr.sessions.values() if s.trace_id == "t1"] assert len(sessions) == 1 assert len(sessions[0].spans) == 2 @@ -796,7 +791,6 @@ async def go(): def test_different_traces_create_different_sessions(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) body = { "resourceSpans": [ { @@ -823,7 +817,7 @@ async def go(): }, ] } - await _process_traces(body) + await _process_traces(body, mgr) trace_ids = {s.trace_id for s in mgr.sessions.values()} assert "t1" in trace_ids assert "t2" in trace_ids @@ -834,13 +828,12 @@ async def go(): def test_scope_injected_into_spans(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) body = _make_export_request( spans=[_make_span(trace_id="t1")], scope_name="gcp.vertex.agent", scope_version="1.2.3", ) - await _process_traces(body) + await _process_traces(body, mgr) session = list(mgr.sessions.values())[0] span = session.spans[0] attr_map = {a["key"]: a["value"] for a in span["attributes"]} @@ -856,11 +849,10 @@ def test_root_span_schedules_completion(self): async def go(): mgr = _make_mgr() mgr.schedule_session_completion = MagicMock() - set_trace_manager(mgr) body = _make_export_request( spans=[_make_span(trace_id="t1", parent_span_id=None)], ) - await _process_traces(body) + await _process_traces(body, mgr) session = list(mgr.sessions.values())[0] assert session.has_root_span is True mgr.schedule_session_completion.assert_called_once_with(session.session_id) @@ -872,14 +864,13 @@ def test_idle_timer_reset_on_each_span(self): async def go(): mgr = _make_mgr() mgr.reset_idle_timer = MagicMock() - set_trace_manager(mgr) body = _make_export_request( spans=[ _make_span(trace_id="t1", span_id="s1"), _make_span(trace_id="t1", span_id="s2"), ], ) - await _process_traces(body) + await _process_traces(body, mgr) assert mgr.reset_idle_timer.call_count == 2 _cancel_timers(mgr) @@ -890,7 +881,6 @@ def test_multi_trace_same_session(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) meta = _make_resource_attrs(session_name="my-session") body1 = _make_export_request( @@ -901,8 +891,8 @@ async def go(): spans=[_make_span(trace_id="trace-b", span_id="s2")], resource_attrs=meta, ) - await _process_traces(body1) - await _process_traces(body2) + await _process_traces(body1, mgr) + await _process_traces(body2, mgr) assert len(mgr.sessions) == 1 session = mgr.sessions["my-session"] @@ -917,7 +907,6 @@ def test_logs_route_to_multi_trace_session(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) meta = _make_resource_attrs(session_name="my-session") body1 = _make_export_request( @@ -928,8 +917,8 @@ async def go(): spans=[_make_span(trace_id="trace-b")], resource_attrs=meta, ) - await _process_traces(body1) - await _process_traces(body2) + await _process_traces(body1, mgr) + await _process_traces(body2, mgr) log_body = { "resourceLogs": [ @@ -960,7 +949,7 @@ async def go(): } ] } - await _process_logs(log_body) + await _process_logs(log_body, mgr) session = mgr.sessions["my-session"] assert len(session.logs) == 2 @@ -971,8 +960,7 @@ async def go(): def test_empty_request(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) - await _process_traces({"resourceSpans": []}) + await _process_traces({"resourceSpans": []}, mgr) assert len(mgr.sessions) == 0 _run(go()) @@ -980,11 +968,10 @@ async def go(): def test_broadcasts_span_received(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) body = _make_export_request( spans=[_make_span(trace_id="t1")], ) - await _process_traces(body) + await _process_traces(body, mgr) span_received_calls = [c for c in mgr.broadcast_to_ui.call_args_list if c[0][0]["type"] == "span_received"] assert len(span_received_calls) == 1 _cancel_timers(mgr) @@ -1005,7 +992,6 @@ def test_logs_buffered_when_no_session_exists(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) body = { "resourceLogs": [ { @@ -1037,7 +1023,7 @@ async def go(): } ] } - await _process_logs(body) + await _process_logs(body, mgr) assert len(mgr._orphan_logs) == 1 assert mgr._orphan_logs[0]["trace_id"] == "trace-1" assert mgr._orphan_logs[0]["session_name"] == "my-agent" @@ -1049,7 +1035,6 @@ def test_orphan_logs_replayed_on_session_creation(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) log_body = { "resourceLogs": [ @@ -1082,7 +1067,7 @@ async def go(): } ] } - await _process_logs(log_body) + await _process_logs(log_body, mgr) assert len(mgr._orphan_logs) == 1 meta = {"eval_set_id": None, "session_name": "my-agent", "resource_attrs": {}} @@ -1099,7 +1084,6 @@ def test_orphan_logs_matched_by_session_name(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) log_body = { "resourceLogs": [ @@ -1132,7 +1116,7 @@ async def go(): } ] } - await _process_logs(log_body) + await _process_logs(log_body, mgr) meta = {"eval_set_id": None, "session_name": "my-agent", "resource_attrs": {}} session = await mgr.get_or_create_otlp_session("trace-1", meta) @@ -1147,7 +1131,6 @@ def test_expired_orphan_logs_not_replayed(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) mgr._orphan_logs.append( { @@ -1175,7 +1158,6 @@ def test_multiple_orphan_logs_for_same_session(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) for i in range(3): log_body = { @@ -1209,7 +1191,7 @@ async def go(): } ] } - await _process_logs(log_body) + await _process_logs(log_body, mgr) assert len(mgr._orphan_logs) == 3 @@ -1226,7 +1208,6 @@ class TestProcessLogs: def test_routes_log_to_session_by_trace_id(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) meta = {"eval_set_id": None, "session_name": "s1", "resource_attrs": {}} await mgr.get_or_create_otlp_session("trace-abc", meta) body = { @@ -1250,7 +1231,7 @@ async def go(): } ] } - await _process_logs(body) + await _process_logs(body, mgr) session = mgr.sessions["s1"] assert len(session.logs) == 1 assert session.logs[0]["event_name"] == "gen_ai.user.message" @@ -1260,7 +1241,6 @@ async def go(): def test_buffers_log_with_unknown_trace_id(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) body = { "resourceLogs": [ { @@ -1282,7 +1262,7 @@ async def go(): } ] } - await _process_logs(body) + await _process_logs(body, mgr) assert len(mgr.sessions) == 0 assert len(mgr._orphan_logs) == 1 @@ -1291,7 +1271,6 @@ async def go(): def test_ignores_non_genai_logs(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) meta = {"eval_set_id": None, "session_name": "s1", "resource_attrs": {}} await mgr.get_or_create_otlp_session("trace-1", meta) body = { @@ -1315,7 +1294,7 @@ async def go(): } ] } - await _process_logs(body) + await _process_logs(body, mgr) session = mgr.sessions["s1"] assert len(session.logs) == 0 @@ -1595,7 +1574,6 @@ class TestProtobufJsonParity: def test_protobuf_traces_create_session(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) resource_attrs = [ KeyValue(key="agentevals.session_name", value=AnyValue(string_value="pb-session")), @@ -1606,7 +1584,7 @@ async def go(): raw = request.SerializeToString() body = _decode_protobuf_traces(raw) - await _process_traces(body) + await _process_traces(body, mgr) assert "pb-session" in mgr.sessions session = mgr.sessions["pb-session"] @@ -1621,14 +1599,13 @@ def test_protobuf_root_span_schedules_completion(self): async def go(): mgr = _make_mgr() mgr.schedule_session_completion = MagicMock() - set_trace_manager(mgr) span = _make_pb_span(parent_span_id_hex=None) request = _make_pb_export_request([span]) raw = request.SerializeToString() body = _decode_protobuf_traces(raw) - await _process_traces(body) + await _process_traces(body, mgr) session = list(mgr.sessions.values())[0] assert session.has_root_span is True @@ -1640,14 +1617,13 @@ async def go(): def test_protobuf_scope_injection(self): async def go(): mgr = _make_mgr() - set_trace_manager(mgr) span = _make_pb_span() request = _make_pb_export_request([span], scope_name="strands.agent", scope_version="2.0.0") raw = request.SerializeToString() body = _decode_protobuf_traces(raw) - await _process_traces(body) + await _process_traces(body, mgr) session = list(mgr.sessions.values())[0] stored_span = session.spans[0]