From fd7372e38f8476e41b97349964c7e42e178bf0a1 Mon Sep 17 00:00:00 2001 From: Ankit Singhal Date: Wed, 1 Apr 2026 12:23:04 -0700 Subject: [PATCH] Fix invoke_agent span parenting so framework child spans nest correctly Add set_current_span/detach_context to TracingHelper so the invoke_agent span is attached as the active OTel context span. This ensures child spans created by framework handlers are parented under invoke_agent instead of appearing as siblings. - TracingHelper.set_current_span() attaches the span to context - TracingHelper.detach_context() restores the previous context - InvocationHandler propagates span_token through both streaming and non-streaming cleanup paths - Added test_span_parenting.py to verify parent-child relationship Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../azure/ai/agentserver/core/_tracing.py | 37 ++++- .../ai/agentserver/invocations/_invocation.py | 27 ++-- .../tests/test_span_parenting.py | 129 ++++++++++++++++++ 3 files changed, 182 insertions(+), 11 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py index eefd4df7b2be..3e79af4ed93c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_tracing.py @@ -65,12 +65,12 @@ _HAS_OTEL = False _HAS_BAGGAGE = False try: - from opentelemetry import trace + from opentelemetry import context as _otel_context, trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator _HAS_OTEL = True try: - from opentelemetry import baggage as _otel_baggage, context as _otel_context + from opentelemetry import baggage as _otel_baggage _HAS_BAGGAGE = True except ImportError: @@ -561,6 +561,39 @@ def detach_baggage(token: Any) -> None: if token is not None and _HAS_BAGGAGE: _otel_context.detach(token) + @staticmethod + def set_current_span(span: Any) -> Any: + """Set a span as the current span in the OTel context. + + This makes *span* the active parent for any child spans created + by downstream code (e.g. framework handlers). Without this, + spans created inside the handler would become siblings rather + than children of *span*. + + Returns a context token that **must** be passed to + :meth:`detach_context` when the scope ends. No-op when *span* + is ``None`` or tracing is not available. + + :param span: The OTel span to make current, or *None*. + :type span: Any + :return: A context token, or *None*. + :rtype: Any + """ + if span is None or not _HAS_OTEL: + return None + ctx = trace.set_span_in_context(span) + return _otel_context.attach(ctx) + + @staticmethod + def detach_context(token: Any) -> None: + """Detach a context previously attached by :meth:`set_current_span`. + + :param token: The token returned by :meth:`set_current_span`. + :type token: Any + """ + if token is not None and _HAS_OTEL: + _otel_context.detach(token) + async def trace_stream( self, iterator: AsyncIterable[_Content], span: Any ) -> AsyncIterator[_Content]: diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py index a89b88188a6a..2c2a56554615 100644 --- a/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py +++ b/sdk/agentserver/azure-ai-agentserver-invocations/azure/ai/agentserver/invocations/_invocation.py @@ -259,22 +259,24 @@ def _wrap_streaming_response( response: StreamingResponse, otel_span: Any, baggage_token: Any, + span_token: Any, ) -> StreamingResponse: - """Wrap a streaming response's body iterator with tracing and baggage cleanup. + """Wrap a streaming response's body iterator with tracing and context cleanup. Two layers of wrapping are applied in order: 1. **Inner (tracing):** ``trace_stream`` wraps the body iterator so the OTel span covers the full streaming duration and records any errors that occur while yielding chunks. - 2. **Outer (baggage cleanup):** A second async generator detaches the - W3C Baggage context *after* all chunks have been sent (or an - error occurs). This ordering ensures the span is ended before - the baggage context is detached. + 2. **Outer (context cleanup):** A second async generator detaches the + span context and W3C Baggage context *after* all chunks have been + sent (or an error occurs). This ordering ensures the span is + ended before the contexts are detached. :param response: The ``StreamingResponse`` returned by the user handler. :param otel_span: The OTel span (or *None* when tracing is disabled). :param baggage_token: Token from ``set_baggage`` (or *None*). + :param span_token: Token from ``set_current_span`` (or *None*). :return: The same response object, with its body_iterator replaced. """ # When tracing is disabled there is nothing to wrap — skip the @@ -286,7 +288,7 @@ def _wrap_streaming_response( # Inner wrap: trace_stream ends the span when iteration completes. response.body_iterator = self._tracing.trace_stream(response.body_iterator, otel_span) - # Outer wrap: detach baggage after all chunks are sent. + # Outer wrap: detach span context and baggage after all chunks are sent. original_iterator = response.body_iterator tracing = self._tracing # capture for the closure @@ -295,6 +297,7 @@ async def _cleanup_iter(): # type: ignore[return-value] async for chunk in original_iterator: yield chunk finally: + tracing.detach_context(span_token) tracing.detach_baggage(baggage_token) response.body_iterator = _cleanup_iter() @@ -326,6 +329,7 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: request.state.session_id = session_id baggage_token = None + span_token = None response: Optional[Response] = None streaming_wrapped = False @@ -339,6 +343,10 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: operation_name="invoke_agent", session_id=session_id, ) + # Make the span the current span in context so that + # child spans created by framework handlers are correctly + # parented under this span instead of appearing as siblings. + span_token = self._tracing.set_current_span(otel_span) self._safe_set_attrs(otel_span, { InvocationConstants.ATTR_SPAN_INVOCATION_ID: invocation_id, InvocationConstants.ATTR_SPAN_SESSION_ID: session_id, @@ -388,7 +396,7 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: ) if isinstance(response, StreamingResponse): - wrapped = self._wrap_streaming_response(response, otel_span, baggage_token) + wrapped = self._wrap_streaming_response(response, otel_span, baggage_token, span_token) streaming_wrapped = True return wrapped @@ -399,11 +407,12 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: return response finally: # For non-streaming responses (or error paths that returned - # before reaching _wrap_streaming_response), detach baggage - # immediately. Streaming responses handle this in + # before reaching _wrap_streaming_response), detach context + # and baggage immediately. Streaming responses handle this in # _wrap_streaming_response's cleanup iterator instead. if not streaming_wrapped: if self._tracing is not None: + self._tracing.detach_context(span_token) self._tracing.detach_baggage(baggage_token) async def _traced_invocation_endpoint( diff --git a/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py new file mode 100644 index 000000000000..d874db2b741f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-invocations/tests/test_span_parenting.py @@ -0,0 +1,129 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests that the invoke_agent span is set as the current span in context, +so that child spans created by framework handlers are correctly parented.""" +import os +from unittest.mock import patch + +import pytest +from httpx import ASGITransport, AsyncClient +from starlette.requests import Request +from starlette.responses import Response, StreamingResponse + +from azure.ai.agentserver.core import AgentHost +from azure.ai.agentserver.invocations import InvocationHandler + +try: + from opentelemetry import trace + from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + _HAS_OTEL = True +except ImportError: + _HAS_OTEL = False + +if _HAS_OTEL: + _EXPORTER = InMemorySpanExporter() + _PROVIDER = SdkTracerProvider() + _PROVIDER.add_span_processor(SimpleSpanProcessor(_EXPORTER)) + trace.set_tracer_provider(_PROVIDER) +else: + _EXPORTER = None + +pytestmark = pytest.mark.skipif(not _HAS_OTEL, reason="opentelemetry not installed") + + +@pytest.fixture(autouse=True) +def _clear(): + if _EXPORTER: + _EXPORTER.clear() + + +def _get_spans(): + return list(_EXPORTER.get_finished_spans()) if _EXPORTER else [] + + +def _make_server_with_child_span(): + """Server whose handler creates a child span (simulating a framework).""" + with patch("azure.ai.agentserver.core._tracing.TracingHelper._setup_azure_monitor"): + server = AgentHost() + invocations = InvocationHandler(server) + child_tracer = trace.get_tracer("test.framework") + + @invocations.invoke_handler + async def handle(request: Request) -> Response: + # Simulate a framework creating its own invoke_agent span + with child_tracer.start_as_current_span("framework_invoke_agent") as _span: + return Response(content=b"ok") + + return server + + +def _make_streaming_server_with_child_span(): + """Server with streaming response whose handler creates a child span.""" + with patch("azure.ai.agentserver.core._tracing.TracingHelper._setup_azure_monitor"): + server = AgentHost() + invocations = InvocationHandler(server) + child_tracer = trace.get_tracer("test.framework") + + @invocations.invoke_handler + async def handle(request: Request) -> StreamingResponse: + with child_tracer.start_as_current_span("framework_invoke_agent"): + async def generate(): + yield b"chunk\n" + return StreamingResponse(generate(), media_type="text/plain") + + return server + + +@pytest.mark.asyncio +async def test_framework_span_is_child_of_invoke_span(): + """A span created inside the handler should be a child of the + agentserver invoke_agent span, not a sibling.""" + server = _make_server_with_child_span() + transport = ASGITransport(app=server.app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b"test") + + spans = _get_spans() + parent_spans = [s for s in spans if "invoke_agent" in s.name] + child_spans = [s for s in spans if s.name == "framework_invoke_agent"] + + assert len(parent_spans) >= 1, f"Expected invoke_agent span, got: {[s.name for s in spans]}" + assert len(child_spans) == 1, f"Expected framework span, got: {[s.name for s in spans]}" + + parent = parent_spans[0] + child = child_spans[0] + + # The child span's parent should be the agentserver invoke_agent span + assert child.parent is not None, "Framework span has no parent — it's a root span (sibling)" + assert child.parent.span_id == parent.context.span_id, ( + f"Framework span parent ({format(child.parent.span_id, '016x')}) " + f"!= invoke_agent span ({format(parent.context.span_id, '016x')}). " + "Spans are siblings, not parent-child." + ) + + +@pytest.mark.asyncio +async def test_framework_span_is_child_streaming(): + """Same parent-child relationship holds for streaming responses.""" + server = _make_streaming_server_with_child_span() + transport = ASGITransport(app=server.app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"test") + assert resp.status_code == 200 + + spans = _get_spans() + parent_spans = [s for s in spans if "invoke_agent" in s.name] + child_spans = [s for s in spans if s.name == "framework_invoke_agent"] + + assert len(parent_spans) >= 1 + assert len(child_spans) == 1 + + parent = parent_spans[0] + child = child_spans[0] + + assert child.parent is not None, "Framework span has no parent in streaming case" + assert child.parent.span_id == parent.context.span_id