From c2db1354434c8ac548f0d532c6f79562224cbc29 Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Mon, 26 May 2025 12:34:06 -0400 Subject: [PATCH 1/2] fix(telemetry): fix agent span start and end when using Agent.stream_async() --- src/strands/agent/agent.py | 67 ++++++++++++++++++------ tests/strands/agent/test_agent.py | 87 +++++++++++++++++++++++++++++-- 2 files changed, 135 insertions(+), 19 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bed9e52e6..79fa91e57 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None - - self.trace_span = self.tracer.start_agent_span( - prompt=prompt, - model_id=model_id, - tools=self.tool_names, - system_prompt=self.system_prompt, - custom_trace_attributes=self.trace_attributes, - ) + self._start_agent_trace_span(prompt) try: # Run the event loop and get the result result = self._run_loop(prompt, kwargs) - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, response=result) + self._end_agent_trace_span(span=self.trace_span, response=result) return result except Exception as e: - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, error=e) + self._end_agent_trace_span(span=self.trace_span, error=e) # Re-raise the exception to preserve original behavior raise @@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: yield event["data"] ``` """ + self._start_agent_trace_span(prompt) + _stop_event = uuid4() queue = asyncio.Queue[Any]() @@ -400,8 +392,10 @@ def target_callback() -> None: nonlocal kwargs try: - self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) - except BaseException as e: + result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) + self._end_agent_trace_span(span=self.trace_span, response=result) + except Exception as e: + self._end_agent_trace_span(span=self.trace_span, error=e) enqueue(e) finally: enqueue(_stop_event) @@ -414,7 +408,7 @@ def target_callback() -> None: item = await queue.get() if item == _stop_event: break - if isinstance(item, BaseException): + if isinstance(item, Exception): raise item yield item finally: @@ -546,3 +540,44 @@ def _record_tool_execution( messages.append(tool_use_msg) messages.append(tool_result_msg) messages.append(assistant_msg) + + def _start_agent_trace_span(self, prompt: str) -> None: + """Starts a trace span for the agent. + + Args: + prompt: The natural language prompt from the user. + """ + model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None + + self.trace_span = self.tracer.start_agent_span( + prompt=prompt, + model_id=model_id, + tools=self.tool_names, + system_prompt=self.system_prompt, + custom_trace_attributes=self.trace_attributes, + ) + + def _end_agent_trace_span( + self, + span: Optional[trace.Span] = None, + response: Optional[AgentResult] = None, + error: Optional[Exception] = None, + ) -> None: + """Ends a trace span for the agent. + + Args: + span: The span to end. + response: Response to record as a trace attribute. + error: Error to record as a trace attribute. + """ + if self.trace_span: + trace_attributes: Dict[str, Any] = { + "span": self.trace_span, + } + + if response: + trace_attributes["response"] = response + if error: + trace_attributes["error"] = error + + self.tracer.end_agent_span(**trace_attributes) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ff70089bd..828ae8f96 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -9,7 +9,8 @@ import pytest import strands -from strands.agent.agent import Agent +from strands import Agent +from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler @@ -687,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler(): @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): - mock_event_loop_cycle.side_effect = ValueError("Test exception") - agent = Agent() # Define the side effect to simulate callback handler being called multiple times @@ -952,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle): + """Test that stream_async creates and ends a span when the call succeeds.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate callback handler being called multiple times + def call_callback_handler(*args, **kwargs): + # Extract the callback handler from kwargs + callback_handler = kwargs.get("callback_handler") + # Call the callback handler with different data values + callback_handler(data="First chunk") + callback_handler(data="Second chunk") + callback_handler(data="Final chunk", complete=True) + # Return expected values from event_loop_cycle + return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {} + + mock_event_loop_cycle.side_effect = call_callback_handler + + # Create agent and make a call + agent = Agent(model=mock_model) + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + expected_response = AgentResult( + stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={} + ) + + # Verify span was ended with the result + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): """Test that __call__ creates and ends a span when an exception occurs.""" @@ -985,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): + """Test that stream_async creates and ends a span when the call succeeds.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate callback handler raising an Exception + test_exception = ValueError("Test exception") + mock_model.mock_converse.side_effect = test_exception + + # Create agent and make a call + agent = Agent(model=mock_model) + + # Call the agent and catch the exception + with pytest.raises(ValueError): + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + # Verify span was ended with the exception + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model): """Test that event_loop_cycle is called with the parent span.""" From 666d0ab1304ef19c4db3f27911c2daf3ec6fd182 Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Mon, 26 May 2025 13:01:58 -0400 Subject: [PATCH 2/2] chore(agent): remove unused argument in _end_agent_trace_span --- src/strands/agent/agent.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 79fa91e57..0f912b54b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -334,11 +334,11 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: # Run the event loop and get the result result = self._run_loop(prompt, kwargs) - self._end_agent_trace_span(span=self.trace_span, response=result) + self._end_agent_trace_span(response=result) return result except Exception as e: - self._end_agent_trace_span(span=self.trace_span, error=e) + self._end_agent_trace_span(error=e) # Re-raise the exception to preserve original behavior raise @@ -393,9 +393,9 @@ def target_callback() -> None: try: result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) - self._end_agent_trace_span(span=self.trace_span, response=result) + self._end_agent_trace_span(response=result) except Exception as e: - self._end_agent_trace_span(span=self.trace_span, error=e) + self._end_agent_trace_span(error=e) enqueue(e) finally: enqueue(_stop_event) @@ -559,7 +559,6 @@ def _start_agent_trace_span(self, prompt: str) -> None: def _end_agent_trace_span( self, - span: Optional[trace.Span] = None, response: Optional[AgentResult] = None, error: Optional[Exception] = None, ) -> None: