diff --git a/src/agentunit/adapters/swarm_adapter.py b/src/agentunit/adapters/swarm_adapter.py index 8eefa87..5d3a30c 100644 --- a/src/agentunit/adapters/swarm_adapter.py +++ b/src/agentunit/adapters/swarm_adapter.py @@ -13,6 +13,17 @@ from datetime import datetime from typing import TYPE_CHECKING, Any + +try: + from swarm import Agent, Swarm + + HAS_SWARM = True +except ImportError: + HAS_SWARM = False + logging.getLogger(__name__).warning( + "OpenAI Swarm not installed. SwarmAdapter will have limited functionality." + ) + from agentunit.multiagent import ( AgentInteraction, AgentMetadata, @@ -23,23 +34,14 @@ OrchestrationPattern, SessionID, ) -from agentunit.reporting.results import ScenarioResult if TYPE_CHECKING: from agentunit.core import Scenario + from agentunit.reporting.results import ScenarioResult -# Configure logging logger = logging.getLogger(__name__) -try: - from swarm import Agent, Swarm - - HAS_SWARM = True -except ImportError: - HAS_SWARM = False - logger.warning("OpenAI Swarm not installed. SwarmAdapter will have limited functionality.") - class SwarmAdapter(MultiAgentAdapter): """OpenAI Swarm integration adapter for lightweight multi-agent testing.""" @@ -149,7 +151,11 @@ def wrapped_function(*args, **kwargs): return wrapped_function def initiate_conversation( - self, scenario: Scenario, initial_message: str, participants: list[str], **kwargs + self, + scenario: Scenario, + initial_message: str, + participants: list[str], + **kwargs, ) -> SessionID: """Initiate a Swarm conversation.""" session_id = f"session_{uuid.uuid4().hex[:8]}" @@ -323,6 +329,14 @@ def end_session(self, session_id: SessionID) -> ScenarioResult: session["status"] = "completed" session["end_time"] = datetime.now() + # Calculate duration in milliseconds + start_time = session.get("start_time") + if start_time: + duration_seconds = (session["end_time"] - start_time).total_seconds() + duration_ms = float(duration_seconds * 1000) + else: + duration_ms = 0.0 + # Collect metrics metrics = self._calculate_swarm_metrics(session_id) @@ -330,27 +344,37 @@ def end_session(self, session_id: SessionID) -> ScenarioResult: from agentunit.core.trace import TraceLog trace = TraceLog() - trace.record("session_complete", session_id=session_id, metrics=metrics) + trace.record( + "session_complete", + session_id=session_id, + metrics=metrics, + ) # Create scenario run from agentunit.reporting.results import ScenarioRun scenario_run = ScenarioRun( - scenario_name=session["scenario"].name - if hasattr(session["scenario"], "name") - else "swarm_scenario", + scenario_name=( + session["scenario"].name + if hasattr(session["scenario"], "name") + else "swarm_scenario" + ), case_id=session_id, success=True, metrics=metrics, - duration_ms=0.0, # TODO: Track actual duration + duration_ms=duration_ms, trace=trace, ) # Create result + from agentunit.reporting.results import ScenarioResult + result = ScenarioResult( - name=session["scenario"].name - if hasattr(session["scenario"], "name") - else "swarm_scenario" + name=( + session["scenario"].name + if hasattr(session["scenario"], "name") + else "swarm_scenario" + ) ) result.add_run(scenario_run) @@ -414,7 +438,7 @@ def _analyze_handoff_patterns(self, handoffs: list[dict[str, Any]]) -> dict[str, return { "handoff_counts": patterns, "unique_patterns": len(patterns), - "most_common": max(patterns.items(), key=lambda x: x[1]) if patterns else None, + "most_common": (max(patterns.items(), key=lambda x: x[1]) if patterns else None), } def _calculate_agent_utilization(self, session_id: SessionID) -> dict[str, float]: diff --git a/src/agentunit/reporting/results.py b/src/agentunit/reporting/results.py index 8d94444..1c25cc1 100644 --- a/src/agentunit/reporting/results.py +++ b/src/agentunit/reporting/results.py @@ -200,9 +200,6 @@ def to_csv(self, path: str | Path) -> Path: rows.append(row) - if not rows: - return target - fieldnames = sorted({key for row in rows for key in row}) with target.open("w", newline="", encoding="utf-8") as f: