Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions src/agentevals/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from ..utils.log_buffer import log_buffer
from .debug_routes import debug_router
from .debug_routes import set_trace_manager as set_debug_trace_manager
from .routes import router

try:
Expand All @@ -42,11 +41,12 @@ async def lifespan(app: FastAPI):
if log_buffer not in ae_logger.handlers:
log_buffer.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
ae_logger.addHandler(log_buffer)
if _trace_manager:
_trace_manager.start_cleanup_task()
mgr = getattr(app.state, "trace_manager", None)
if mgr:
mgr.start_cleanup_task()
yield
if _trace_manager:
await _trace_manager.shutdown()
if mgr:
await mgr.shutdown()
ae_logger.removeHandler(log_buffer)


Expand All @@ -70,27 +70,27 @@ async def lifespan(app: FastAPI):
app.include_router(debug_router, prefix="/api/debug")

_live_mode = os.getenv("AGENTEVALS_LIVE") == "1"
_trace_manager = None

if _live_mode:
from fastapi import Request as _Request
from fastapi import WebSocket

from ..streaming.ws_server import StreamingTraceManager
from .streaming_routes import set_trace_manager, streaming_router
from .streaming_routes import streaming_router

app.include_router(streaming_router, prefix="/api/streaming")
_trace_manager = StreamingTraceManager()
set_trace_manager(_trace_manager)
set_debug_trace_manager(_trace_manager)
app.state.trace_manager = StreamingTraceManager()

@app.websocket("/ws/traces")
async def websocket_endpoint(websocket: WebSocket):
await _trace_manager.handle_connection(websocket)
await websocket.app.state.trace_manager.handle_connection(websocket)

@app.get("/stream/ui-updates")
async def ui_updates_stream():
async def ui_updates_stream(request: _Request):
mgr = request.app.state.trace_manager

async def event_generator():
queue = _trace_manager.register_sse_client()
queue = mgr.register_sse_client()
try:
while True:
event = await queue.get()
Expand All @@ -100,7 +100,7 @@ async def event_generator():
except asyncio.CancelledError:
pass
finally:
_trace_manager.unregister_sse_client(queue)
mgr.unregister_sse_client(queue)

return StreamingResponse(
event_generator(),
Expand All @@ -112,10 +112,6 @@ async def event_generator():
)


def get_trace_manager():
return _trace_manager


_static_dir = Path(__file__).parent.parent / "_static"
_has_ui = _static_dir.is_dir() and (_static_dir / "index.html").exists()

Expand Down
44 changes: 19 additions & 25 deletions src/agentevals/api/debug_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from datetime import UTC, datetime
from typing import TYPE_CHECKING

from fastapi import APIRouter, HTTPException, UploadFile
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi import File as FastAPIFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from agentevals import __version__

from ..utils.log_buffer import log_buffer
from .dependencies import get_trace_manager, require_trace_manager
from .models import DebugLoadData, SessionInfo, StandardResponse, WSSessionCompleteEvent, WSSessionStartedEvent

if TYPE_CHECKING:
Expand All @@ -30,13 +31,6 @@

debug_router = APIRouter()

_trace_manager: StreamingTraceManager | None = None


def set_trace_manager(manager: StreamingTraceManager) -> None:
global _trace_manager
_trace_manager = manager


class FrontendDiagnostics(BaseModel):
user_description: str = ""
Expand Down Expand Up @@ -83,12 +77,12 @@ def _collect_environment() -> dict:
}


def _collect_sessions() -> list[dict]:
if not _trace_manager:
def _collect_sessions(manager: StreamingTraceManager | None) -> list[dict]:
if not manager:
return []

sessions_data = []
for session in _trace_manager.sessions.values():
for session in manager.sessions.values():
sessions_data.append(
{
"session_id": session.session_id,
Expand Down Expand Up @@ -128,7 +122,10 @@ def _collect_temp_files(session_ids: set[str] | None = None) -> dict[str, str]:


@debug_router.post("/bundle")
async def create_debug_bundle(diagnostics: FrontendDiagnostics):
async def create_debug_bundle(
diagnostics: FrontendDiagnostics,
manager: StreamingTraceManager | None = Depends(get_trace_manager),
):
timestamp = datetime.now(tz=UTC).strftime("%Y%m%d-%H%M%S")
prefix = f"bug-report-{timestamp}"

Expand All @@ -142,7 +139,7 @@ async def create_debug_bundle(diagnostics: FrontendDiagnostics):
}
zf.writestr(f"{prefix}/metadata.json", json.dumps(metadata, indent=2))

sessions = _collect_sessions()
sessions = _collect_sessions(manager)
for s in sessions:
sid = s["session_id"]
zf.writestr(
Expand Down Expand Up @@ -188,13 +185,10 @@ async def create_debug_bundle(diagnostics: FrontendDiagnostics):


@debug_router.post("/load", response_model=StandardResponse[DebugLoadData])
async def load_debug_bundle(file: UploadFile = FastAPIFile(...)):
if not _trace_manager:
raise HTTPException(
status_code=400,
detail="Live mode is not enabled. Start with: agentevals serve --dev",
)

async def load_debug_bundle(
file: UploadFile = FastAPIFile(...),
manager: StreamingTraceManager = Depends(require_trace_manager),
):
content = await file.read()
try:
zf = zipfile.ZipFile(io.BytesIO(content))
Expand Down Expand Up @@ -236,9 +230,9 @@ async def load_debug_bundle(file: UploadFile = FastAPIFile(...)):
metadata=meta.get("metadata", {}),
)

_trace_manager.sessions[session.session_id] = session
manager.sessions[session.session_id] = session

await _trace_manager.broadcast_to_ui(
await manager.broadcast_to_ui(
WSSessionStartedEvent(
session=SessionInfo(
session_id=session.session_id,
Expand All @@ -252,10 +246,10 @@ async def load_debug_bundle(file: UploadFile = FastAPIFile(...)):
).model_dump(by_alias=True)
)

invocations_data = await _trace_manager._extract_invocations(session)
await _trace_manager._save_spans_to_temp_file(session)
invocations_data = await manager._extract_invocations(session)
await manager._save_spans_to_temp_file(session)

await _trace_manager.broadcast_to_ui(
await manager.broadcast_to_ui(
WSSessionCompleteEvent(
session_id=session.session_id,
invocations=invocations_data,
Expand Down
23 changes: 23 additions & 0 deletions src/agentevals/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""FastAPI dependency functions for shared services."""

from __future__ import annotations

from typing import TYPE_CHECKING

from fastapi import HTTPException, Request

if TYPE_CHECKING:
from ..streaming.ws_server import StreamingTraceManager


def get_trace_manager(request: Request) -> StreamingTraceManager | None:
"""Return the StreamingTraceManager or None if live mode is off."""
return getattr(request.app.state, "trace_manager", None)


def require_trace_manager(request: Request) -> StreamingTraceManager:
"""Return the StreamingTraceManager, raising 503 if live mode is off."""
mgr = getattr(request.app.state, "trace_manager", None)
if mgr is None:
raise HTTPException(status_code=503, detail="Live mode not enabled")
return mgr
8 changes: 4 additions & 4 deletions src/agentevals/api/otlp_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

from fastapi import FastAPI

from .otlp_routes import otlp_router, set_trace_manager
from .otlp_routes import otlp_router


@asynccontextmanager
async def lifespan(app: FastAPI):
from .app import get_trace_manager
from .app import app as main_app

mgr = get_trace_manager()
mgr = getattr(main_app.state, "trace_manager", None)
if mgr:
set_trace_manager(mgr)
app.state.trace_manager = mgr
yield


Expand Down
Loading
Loading