Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@
_HAS_OTEL = False
_HAS_BAGGAGE = False
try:
from opentelemetry import trace
from opentelemetry import context as _otel_context, trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

_HAS_OTEL = True
try:
from opentelemetry import baggage as _otel_baggage, context as _otel_context
from opentelemetry import baggage as _otel_baggage

_HAS_BAGGAGE = True
except ImportError:
Expand Down Expand Up @@ -561,6 +561,39 @@ def detach_baggage(token: Any) -> None:
if token is not None and _HAS_BAGGAGE:
_otel_context.detach(token)

@staticmethod
def set_current_span(span: Any) -> Any:
"""Set a span as the current span in the OTel context.

This makes *span* the active parent for any child spans created
by downstream code (e.g. framework handlers). Without this,
spans created inside the handler would become siblings rather
than children of *span*.

Returns a context token that **must** be passed to
:meth:`detach_context` when the scope ends. No-op when *span*
is ``None`` or tracing is not available.

:param span: The OTel span to make current, or *None*.
:type span: Any
:return: A context token, or *None*.
:rtype: Any
"""
if span is None or not _HAS_OTEL:
return None
ctx = trace.set_span_in_context(span)
return _otel_context.attach(ctx)

@staticmethod
def detach_context(token: Any) -> None:
"""Detach a context previously attached by :meth:`set_current_span`.

:param token: The token returned by :meth:`set_current_span`.
:type token: Any
"""
if token is not None and _HAS_OTEL:
_otel_context.detach(token)

async def trace_stream(
self, iterator: AsyncIterable[_Content], span: Any
) -> AsyncIterator[_Content]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,22 +259,24 @@ def _wrap_streaming_response(
response: StreamingResponse,
otel_span: Any,
baggage_token: Any,
span_token: Any,
) -> StreamingResponse:
"""Wrap a streaming response's body iterator with tracing and baggage cleanup.
"""Wrap a streaming response's body iterator with tracing and context cleanup.

Two layers of wrapping are applied in order:

1. **Inner (tracing):** ``trace_stream`` wraps the body iterator so
the OTel span covers the full streaming duration and records any
errors that occur while yielding chunks.
2. **Outer (baggage cleanup):** A second async generator detaches the
W3C Baggage context *after* all chunks have been sent (or an
error occurs). This ordering ensures the span is ended before
the baggage context is detached.
2. **Outer (context cleanup):** A second async generator detaches the
span context and W3C Baggage context *after* all chunks have been
sent (or an error occurs). This ordering ensures the span is
ended before the contexts are detached.

:param response: The ``StreamingResponse`` returned by the user handler.
:param otel_span: The OTel span (or *None* when tracing is disabled).
:param baggage_token: Token from ``set_baggage`` (or *None*).
:param span_token: Token from ``set_current_span`` (or *None*).
:return: The same response object, with its body_iterator replaced.
"""
# When tracing is disabled there is nothing to wrap — skip the
Expand All @@ -286,7 +288,7 @@ def _wrap_streaming_response(
# Inner wrap: trace_stream ends the span when iteration completes.
response.body_iterator = self._tracing.trace_stream(response.body_iterator, otel_span)

# Outer wrap: detach baggage after all chunks are sent.
# Outer wrap: detach span context and baggage after all chunks are sent.
original_iterator = response.body_iterator
tracing = self._tracing # capture for the closure

Expand All @@ -295,6 +297,7 @@ async def _cleanup_iter(): # type: ignore[return-value]
async for chunk in original_iterator:
yield chunk
finally:
tracing.detach_context(span_token)
tracing.detach_baggage(baggage_token)

response.body_iterator = _cleanup_iter()
Expand Down Expand Up @@ -326,6 +329,7 @@ async def _create_invocation_endpoint(self, request: Request) -> Response:
request.state.session_id = session_id

baggage_token = None
span_token = None
response: Optional[Response] = None
streaming_wrapped = False

Expand All @@ -339,6 +343,10 @@ async def _create_invocation_endpoint(self, request: Request) -> Response:
operation_name="invoke_agent",
session_id=session_id,
)
# Make the span the current span in context so that
# child spans created by framework handlers are correctly
# parented under this span instead of appearing as siblings.
span_token = self._tracing.set_current_span(otel_span)
self._safe_set_attrs(otel_span, {
InvocationConstants.ATTR_SPAN_INVOCATION_ID: invocation_id,
InvocationConstants.ATTR_SPAN_SESSION_ID: session_id,
Expand Down Expand Up @@ -388,7 +396,7 @@ async def _create_invocation_endpoint(self, request: Request) -> Response:
)

if isinstance(response, StreamingResponse):
wrapped = self._wrap_streaming_response(response, otel_span, baggage_token)
wrapped = self._wrap_streaming_response(response, otel_span, baggage_token, span_token)
streaming_wrapped = True
return wrapped

Expand All @@ -399,11 +407,12 @@ async def _create_invocation_endpoint(self, request: Request) -> Response:
return response
finally:
# For non-streaming responses (or error paths that returned
# before reaching _wrap_streaming_response), detach baggage
# immediately. Streaming responses handle this in
# before reaching _wrap_streaming_response), detach context
# and baggage immediately. Streaming responses handle this in
# _wrap_streaming_response's cleanup iterator instead.
if not streaming_wrapped:
if self._tracing is not None:
self._tracing.detach_context(span_token)
self._tracing.detach_baggage(baggage_token)

async def _traced_invocation_endpoint(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Tests that the invoke_agent span is set as the current span in context,
so that child spans created by framework handlers are correctly parented."""
import os
from unittest.mock import patch

import pytest
from httpx import ASGITransport, AsyncClient
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse

from azure.ai.agentserver.core import AgentHost
from azure.ai.agentserver.invocations import InvocationHandler

try:
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

_HAS_OTEL = True
except ImportError:
_HAS_OTEL = False

if _HAS_OTEL:
_EXPORTER = InMemorySpanExporter()
_PROVIDER = SdkTracerProvider()
_PROVIDER.add_span_processor(SimpleSpanProcessor(_EXPORTER))
trace.set_tracer_provider(_PROVIDER)
else:
_EXPORTER = None

pytestmark = pytest.mark.skipif(not _HAS_OTEL, reason="opentelemetry not installed")


@pytest.fixture(autouse=True)
def _clear():
if _EXPORTER:
_EXPORTER.clear()


def _get_spans():
return list(_EXPORTER.get_finished_spans()) if _EXPORTER else []


def _make_server_with_child_span():
"""Server whose handler creates a child span (simulating a framework)."""
with patch("azure.ai.agentserver.core._tracing.TracingHelper._setup_azure_monitor"):
server = AgentHost()
invocations = InvocationHandler(server)
child_tracer = trace.get_tracer("test.framework")

@invocations.invoke_handler
async def handle(request: Request) -> Response:
# Simulate a framework creating its own invoke_agent span
with child_tracer.start_as_current_span("framework_invoke_agent") as _span:
return Response(content=b"ok")

return server


def _make_streaming_server_with_child_span():
"""Server with streaming response whose handler creates a child span."""
with patch("azure.ai.agentserver.core._tracing.TracingHelper._setup_azure_monitor"):
server = AgentHost()
invocations = InvocationHandler(server)
child_tracer = trace.get_tracer("test.framework")

@invocations.invoke_handler
async def handle(request: Request) -> StreamingResponse:
with child_tracer.start_as_current_span("framework_invoke_agent"):
async def generate():
yield b"chunk\n"
return StreamingResponse(generate(), media_type="text/plain")

return server


@pytest.mark.asyncio
async def test_framework_span_is_child_of_invoke_span():
"""A span created inside the handler should be a child of the
agentserver invoke_agent span, not a sibling."""
server = _make_server_with_child_span()
transport = ASGITransport(app=server.app)
async with AsyncClient(transport=transport, base_url="http://testserver") as client:
await client.post("/invocations", content=b"test")

spans = _get_spans()
parent_spans = [s for s in spans if "invoke_agent" in s.name]
child_spans = [s for s in spans if s.name == "framework_invoke_agent"]

assert len(parent_spans) >= 1, f"Expected invoke_agent span, got: {[s.name for s in spans]}"
assert len(child_spans) == 1, f"Expected framework span, got: {[s.name for s in spans]}"

parent = parent_spans[0]
child = child_spans[0]

# The child span's parent should be the agentserver invoke_agent span
assert child.parent is not None, "Framework span has no parent — it's a root span (sibling)"
assert child.parent.span_id == parent.context.span_id, (
f"Framework span parent ({format(child.parent.span_id, '016x')}) "
f"!= invoke_agent span ({format(parent.context.span_id, '016x')}). "
"Spans are siblings, not parent-child."
)


@pytest.mark.asyncio
async def test_framework_span_is_child_streaming():
"""Same parent-child relationship holds for streaming responses."""
server = _make_streaming_server_with_child_span()
transport = ASGITransport(app=server.app)
async with AsyncClient(transport=transport, base_url="http://testserver") as client:
resp = await client.post("/invocations", content=b"test")
assert resp.status_code == 200

spans = _get_spans()
parent_spans = [s for s in spans if "invoke_agent" in s.name]
child_spans = [s for s in spans if s.name == "framework_invoke_agent"]

assert len(parent_spans) >= 1
assert len(child_spans) == 1

parent = parent_spans[0]
child = child_spans[0]

assert child.parent is not None, "Framework span has no parent in streaming case"
assert child.parent.span_id == parent.context.span_id