From cd9e32b9878fda55cd0245aee004792d9decba0d Mon Sep 17 00:00:00 2001 From: "Nikhil Chitlur Navakiran (from Dev Box)" Date: Fri, 24 Apr 2026 01:48:54 +0530 Subject: [PATCH] update mapper for open ai --- .../extensions/openai/message_mapper.py | 381 ++++++++++++++++++ .../extensions/openai/trace_instrumentor.py | 2 +- .../extensions/openai/trace_processor.py | 25 ++ .../openai/integration/test_message_format.py | 284 +++++++++++++ .../extensions/openai/test_message_mapper.py | 261 ++++++++++++ 5 files changed, 952 insertions(+), 1 deletion(-) create mode 100644 libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/message_mapper.py create mode 100644 tests/observability/extensions/openai/integration/test_message_format.py create mode 100644 tests/observability/extensions/openai/test_message_mapper.py diff --git a/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/message_mapper.py b/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/message_mapper.py new file mode 100644 index 00000000..6cd5ea49 --- /dev/null +++ b/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/message_mapper.py @@ -0,0 +1,381 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Maps OpenAI span tag messages to A365 versioned message format. + +Handles three input shapes produced by the OpenAI trace processor: + +1. **Chat-completions format** (from ``GenerationSpanData``): + ``[{"role":"system","content":"..."}, ...]`` +2. **Response API format** (from ``ResponseSpanData``): + - Input: ``[{"type":"message","role":"user","content":"..."}, ...]`` + - Output: ``{"id":"...","model":"...","output":[...], ...}`` (full Response JSON) +3. **Plain string** (from ``AgentSpanData``): + A bare user/assistant message captured from child generation spans. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Mapping +from typing import Any + +from microsoft_agents_a365.observability.core.message_utils import serialize_messages +from microsoft_agents_a365.observability.core.models.messages import ( + ChatMessage, + InputMessages, + MessagePart, + MessageRole, + OutputMessage, + OutputMessages, + TextPart, + ToolCallRequestPart, + ToolCallResponsePart, +) + +logger = logging.getLogger(__name__) + +_ROLE_MAP: dict[str, MessageRole] = { + "system": MessageRole.SYSTEM, + "user": MessageRole.USER, + "assistant": MessageRole.ASSISTANT, + "tool": MessageRole.TOOL, +} + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def map_input_messages(messages_json: str) -> str | None: + """Map a ``gen_ai.input.messages`` tag value to a serialized A365 JSON string. + + Args: + messages_json: The raw JSON string from the span attribute. + + Returns: + Serialized :class:`InputMessages` JSON string, or ``None`` if the + input is empty or cannot be parsed. + """ + if not messages_json: + return None + + # Plain string (AgentSpanData captures bare user text) + try: + raw = json.loads(messages_json) + except (json.JSONDecodeError, TypeError): + return _wrap_plain_input(messages_json) + + if isinstance(raw, list): + return _map_input_list(raw) + + # Unexpected shape + return _wrap_plain_input(messages_json) + + +def map_output_messages(messages_json: str) -> str | None: + """Map a ``gen_ai.output.messages`` tag value to a serialized A365 JSON string. + + Args: + messages_json: The raw JSON string from the span attribute. + + Returns: + Serialized :class:`OutputMessages` JSON string, or ``None`` if the + input is empty or cannot be parsed. + """ + if not messages_json: + return None + + try: + raw = json.loads(messages_json) + except (json.JSONDecodeError, TypeError): + return _wrap_plain_output(messages_json) + + if isinstance(raw, list): + return _map_output_list(raw) + + if isinstance(raw, dict): + # Full Response JSON from ResponseSpanData (model_dump_json) + return _map_response_output(raw) + + return _wrap_plain_output(messages_json) + + +# --------------------------------------------------------------------------- +# Input mapping +# --------------------------------------------------------------------------- + + +def _map_input_list(items: list[Any]) -> str | None: + """Map a list of input items (chat completions or ResponseInputItemParam).""" + chat_messages: list[ChatMessage] = [] + + for item in items: + if not isinstance(item, dict): + continue + + item_type = item.get("type") + + if item_type == "function_call": + # ResponseInputItemParam: function_call → assistant tool call request + name = item.get("name", "") + if name: + parts: list[MessagePart] = [ + ToolCallRequestPart( + name=name, + id=item.get("call_id"), + arguments=item.get("arguments"), + ) + ] + chat_messages.append(ChatMessage(role=MessageRole.ASSISTANT, parts=parts)) + + elif item_type == "function_call_output": + # ResponseInputItemParam: function_call_output → tool response + parts = [ + ToolCallResponsePart( + id=item.get("call_id"), + response=item.get("output"), + ) + ] + chat_messages.append(ChatMessage(role=MessageRole.TOOL, parts=parts)) + + elif item_type == "custom_tool_call": + name = item.get("name", "") + if name: + input_data = item.get("input") + args = json.dumps({"input": input_data}) if input_data is not None else None + parts = [ToolCallRequestPart(name=name, id=item.get("call_id"), arguments=args)] + chat_messages.append(ChatMessage(role=MessageRole.ASSISTANT, parts=parts)) + + elif item_type == "custom_tool_call_output": + parts = [ + ToolCallResponsePart( + id=item.get("call_id"), + response=item.get("output"), + ) + ] + chat_messages.append(ChatMessage(role=MessageRole.TOOL, parts=parts)) + + elif item_type == "message" or "role" in item: + # Standard message (ResponseInputItemParam or chat completions) + mapped = _map_chat_completions_message(item) + if mapped is not None: + chat_messages.append(mapped) + + else: + # Unknown type, try as generic message + mapped = _map_chat_completions_message(item) + if mapped is not None: + chat_messages.append(mapped) + + if not chat_messages: + return None + return serialize_messages(InputMessages(messages=chat_messages)) + + +def _map_chat_completions_message(msg: dict[str, Any]) -> ChatMessage | None: + """Map a single chat-completions-style message dict.""" + role_str = msg.get("role", "") + role = _ROLE_MAP.get(str(role_str).lower(), MessageRole.USER) + parts: list[MessagePart] = [] + + # Tool response message + if role == MessageRole.TOOL: + content = msg.get("content", "") + tool_call_id = msg.get("tool_call_id") + response = str(content) if content else "" + if response or tool_call_id: + parts.append(ToolCallResponsePart(id=tool_call_id, response=response)) + return ChatMessage(role=role, parts=parts) if parts else None + + # Text content (string or list) + content = msg.get("content") + if isinstance(content, str) and content.strip(): + parts.append(TextPart(content=content)) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict): + if item.get("type") in ("input_text", "text"): + text = item.get("text", "") + if text: + parts.append(TextPart(content=text)) + elif item.get("type") == "output_text": + text = item.get("text", "") + if text: + parts.append(TextPart(content=text)) + + # Tool calls + tool_calls = msg.get("tool_calls") + if isinstance(tool_calls, list): + for tc in tool_calls: + if not isinstance(tc, dict): + continue + func = tc.get("function", {}) + if isinstance(func, dict): + name = func.get("name") + if name: + parts.append( + ToolCallRequestPart( + name=name, + id=tc.get("id"), + arguments=func.get("arguments"), + ) + ) + + if not parts: + return None + return ChatMessage(role=role, parts=parts, name=msg.get("name")) + + +# --------------------------------------------------------------------------- +# Output mapping +# --------------------------------------------------------------------------- + + +def _map_output_list(items: list[Any]) -> str | None: + """Map a list of chat-completions-style output messages.""" + output_messages: list[OutputMessage] = [] + + for item in items: + if not isinstance(item, dict): + continue + role_str = item.get("role", "assistant") + role = _ROLE_MAP.get(str(role_str).lower(), MessageRole.ASSISTANT) + parts: list[MessagePart] = [] + + # Tool response + if role == MessageRole.TOOL: + content = item.get("content", "") + tool_call_id = item.get("tool_call_id") + response = str(content) if content else "" + if response or tool_call_id: + parts.append(ToolCallResponsePart(id=tool_call_id, response=response)) + else: + # Text content + content = item.get("content") + if isinstance(content, str) and content.strip(): + parts.append(TextPart(content=content)) + elif isinstance(content, list): + for c in content: + if isinstance(c, dict): + text = c.get("text", "") + if text: + parts.append(TextPart(content=text)) + + # Tool calls + tool_calls = item.get("tool_calls") + if isinstance(tool_calls, list): + for tc in tool_calls: + if not isinstance(tc, dict): + continue + func = tc.get("function", {}) + if isinstance(func, dict): + name = func.get("name") + if name: + parts.append( + ToolCallRequestPart( + name=name, + id=tc.get("id"), + arguments=func.get("arguments"), + ) + ) + + finish_reason = item.get("finish_reason") + if parts: + output_messages.append( + OutputMessage(role=role, parts=parts, finish_reason=finish_reason) + ) + + if not output_messages: + return None + return serialize_messages(OutputMessages(messages=output_messages)) + + +def _map_response_output(response: dict[str, Any]) -> str | None: + """Map a full OpenAI Response JSON to A365 OutputMessages. + + The Response object has ``output: [...]`` containing items with + ``type`` of ``message`` or ``function_call``. + """ + output_items = response.get("output") + if not isinstance(output_items, list): + return None + + output_messages: list[OutputMessage] = [] + + for item in output_items: + if not isinstance(item, Mapping): + continue + item_type = item.get("type") + + if item_type == "message": + parts: list[MessagePart] = [] + role_str = item.get("role", "assistant") + role = _ROLE_MAP.get(str(role_str).lower(), MessageRole.ASSISTANT) + + for content_item in item.get("content", []): + if isinstance(content_item, Mapping): + content_type = content_item.get("type") + if content_type == "output_text": + text = content_item.get("text", "") + if text: + parts.append(TextPart(content=text)) + elif content_type == "refusal": + text = content_item.get("refusal", "") + if text: + parts.append(TextPart(content=text)) + + if parts: + finish_reason = item.get("status") + output_messages.append( + OutputMessage(role=role, parts=parts, finish_reason=finish_reason) + ) + + elif item_type == "function_call": + name = item.get("name", "") + if name: + parts = [ + ToolCallRequestPart( + name=name, + id=item.get("call_id"), + arguments=item.get("arguments"), + ) + ] + output_messages.append( + OutputMessage( + role=MessageRole.ASSISTANT, + parts=parts, + finish_reason="tool_call", + ) + ) + + if not output_messages: + return None + return serialize_messages(OutputMessages(messages=output_messages)) + + +# --------------------------------------------------------------------------- +# Plain-string wrappers +# --------------------------------------------------------------------------- + + +def _wrap_plain_input(text: str) -> str | None: + """Wrap a plain text string as a versioned InputMessages.""" + if not text or not text.strip(): + return None + return serialize_messages( + InputMessages(messages=[ChatMessage(role=MessageRole.USER, parts=[TextPart(content=text)])]) + ) + + +def _wrap_plain_output(text: str) -> str | None: + """Wrap a plain text string as a versioned OutputMessages.""" + if not text or not text.strip(): + return None + return serialize_messages( + OutputMessages( + messages=[OutputMessage(role=MessageRole.ASSISTANT, parts=[TextPart(content=text)])] + ) + ) diff --git a/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_instrumentor.py b/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_instrumentor.py index dae8786e..fec6abf3 100644 --- a/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_instrumentor.py +++ b/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_instrumentor.py @@ -68,4 +68,4 @@ def _instrument(self, **kwargs: Any) -> None: set_trace_processors([OpenAIAgentsTraceProcessor(agent365_tracer)]) def _uninstrument(self, **kwargs: Any) -> None: - pass + set_trace_processors([]) diff --git a/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_processor.py b/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_processor.py index 1a53ca60..eac4b67d 100644 --- a/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_processor.py +++ b/libraries/microsoft-agents-a365-observability-extensions-openai/microsoft_agents_a365/observability/extensions/openai/trace_processor.py @@ -20,6 +20,7 @@ ResponseSpanData, ) from microsoft_agents_a365.observability.core.constants import ( + CHAT_OPERATION_NAME, CUSTOM_PARENT_SPAN_ID_KEY, EXECUTE_TOOL_OPERATION_NAME, GEN_AI_INPUT_MESSAGES_KEY, @@ -49,6 +50,7 @@ from .constants import ( GEN_AI_GRAPH_NODE_PARENT_ID, ) +from .message_mapper import map_input_messages, map_output_messages from .utils import ( capture_input_message, capture_output_message, @@ -258,6 +260,9 @@ def on_span_end(self, span: Span[Any]) -> None: # Clean up tracking self._agent_span_ids.pop(span.span_id, None) + # Map raw messages to A365 versioned format before ending the span + self._apply_message_mapping(otel_span) + end_time: int | None = None if span.ended_at: try: @@ -267,6 +272,26 @@ def on_span_end(self, span: Span[Any]) -> None: otel_span.set_status(status=get_span_status(span)) otel_span.end(end_time) + @staticmethod + def _apply_message_mapping(otel_span: OtelSpan) -> None: + """Map raw ``gen_ai.input/output.messages`` to the A365 versioned format.""" + attrs = otel_span.attributes or {} + operation = attrs.get(GEN_AI_OPERATION_NAME_KEY, "") + if operation not in (INVOKE_AGENT_OPERATION_NAME, CHAT_OPERATION_NAME): + return + + raw_input = attrs.get(GEN_AI_INPUT_MESSAGES_KEY) + if raw_input and isinstance(raw_input, str): + mapped = map_input_messages(raw_input) + if mapped is not None: + otel_span.set_attribute(GEN_AI_INPUT_MESSAGES_KEY, mapped) + + raw_output = attrs.get(GEN_AI_OUTPUT_MESSAGES_KEY) + if raw_output and isinstance(raw_output, str): + mapped = map_output_messages(raw_output) + if mapped is not None: + otel_span.set_attribute(GEN_AI_OUTPUT_MESSAGES_KEY, mapped) + def force_flush(self) -> None: """Forces an immediate flush of all queued spans/traces.""" pass diff --git a/tests/observability/extensions/openai/integration/test_message_format.py b/tests/observability/extensions/openai/integration/test_message_format.py new file mode 100644 index 00000000..705e735c --- /dev/null +++ b/tests/observability/extensions/openai/integration/test_message_format.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Integration tests for OpenAI message format mapping. + +These tests use the real A365 observability pipeline: + configure() → get_tracer_provider() → OpenAIAgentsTraceInstrumentor +with real Azure OpenAI API calls. The message mapping is applied directly +in trace_processor before spans are ended, converting raw OpenAI messages +to the A365 versioned format (v0.1.0) with typed parts. +""" + +import json +import time +from typing import Any + +import pytest +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult + +try: + from agents import Agent, OpenAIChatCompletionsModel, Runner, function_tool + from openai import AsyncAzureOpenAI +except ImportError: + pytest.skip( + "OpenAI agents library and dependencies required for integration tests", + allow_module_level=True, + ) + +from microsoft_agents_a365.observability.core import configure, get_tracer_provider +from microsoft_agents_a365.observability.core.constants import ( + GEN_AI_INPUT_MESSAGES_KEY, + GEN_AI_OUTPUT_MESSAGES_KEY, +) +from microsoft_agents_a365.observability.core.exporters.enriching_span_processor import ( + _EnrichingBatchSpanProcessor, +) +from microsoft_agents_a365.observability.extensions.openai import ( + OpenAIAgentsTraceInstrumentor, +) + + +@function_tool +def get_weather(city: str) -> str: + """Get the current weather for a city. + + Args: + city: The city name to get weather for. + + Returns: + A string describing the weather. + """ + return f"The weather in {city} is sunny, 22°C." + + +class SpanCapturingExporter(SpanExporter): + """Exporter that collects enriched spans in-memory.""" + + def __init__(self) -> None: + self.spans: list[ReadableSpan] = [] + + def export(self, spans: list[ReadableSpan]) -> SpanExportResult: + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + def shutdown(self) -> None: + pass + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + +def _span_to_json(span: ReadableSpan) -> dict[str, object]: + """Convert a ReadableSpan (or EnrichedReadableSpan) to a JSON-serializable dict.""" + try: + ctx = span.get_span_context() + context_dict: dict[str, object] = { + "trace_id": f"0x{ctx.trace_id:032x}", + "span_id": f"0x{ctx.span_id:016x}", + } + except (AttributeError, TypeError): + context_dict = {"note": "context not available on enriched span"} + + try: + parent = span.parent + parent_id = f"0x{parent.span_id:016x}" if parent else None + except (AttributeError, TypeError): + parent_id = None + + events_list: list[dict[str, object]] = [] + for e in getattr(span, "events", None) or []: + events_list.append({ + "name": e.name, + "attributes": dict(e.attributes) if e.attributes else {}, + }) + + links_list: list[dict[str, object]] = [] + for lnk in getattr(span, "links", None) or []: + links_list.append({ + "attributes": dict(lnk.attributes) if lnk.attributes else {}, + }) + + result: dict[str, object] = { + "name": span.name, + "context": context_dict, + "kind": str(getattr(span, "kind", None)), + "parent_id": parent_id, + "status": str(getattr(span, "status", None)), + "attributes": dict(span.attributes) if span.attributes else {}, + "events": events_list, + "links": links_list, + } + + resource = getattr(span, "resource", None) + if resource: + result["resource"] = dict(resource.attributes) if resource.attributes else {} + + scope = getattr(span, "instrumentation_scope", None) + if scope: + result["instrumentation_scope"] = {"name": scope.name, "version": scope.version} + + return result + + +@pytest.mark.integration +class TestOpenAIMessageFormat: + """Capture real OpenAI Agents SDK span attributes after enrichment + and verify the A365 versioned message format.""" + + @pytest.fixture(autouse=True) + def setup_observability(self) -> None: + """Set up A365 observability with OpenAIAgentsTraceInstrumentor.""" + if not hasattr(TestOpenAIMessageFormat, "_exporter"): + configure( + service_name="integration-test-openai-message-format", + service_namespace="agent365-tests", + logger_name="test-logger", + ) + + exporter = SpanCapturingExporter() + provider = get_tracer_provider() + provider.add_span_processor( + _EnrichingBatchSpanProcessor( + exporter, + max_queue_size=100, + schedule_delay_millis=100, + max_export_batch_size=100, + ) + ) + + instrumentor = OpenAIAgentsTraceInstrumentor() + instrumentor.instrument() + + TestOpenAIMessageFormat._exporter = exporter + TestOpenAIMessageFormat._instrumentor = instrumentor + + self.exporter = TestOpenAIMessageFormat._exporter + self.exporter.spans.clear() + + @pytest.fixture + def openai_client(self, azure_openai_config: dict[str, Any]) -> AsyncAzureOpenAI: + """Create a real Azure OpenAI client.""" + return AsyncAzureOpenAI( + api_key=azure_openai_config["api_key"], + api_version=azure_openai_config["api_version"], + azure_endpoint=azure_openai_config["endpoint"], + ) + + def _find_message_spans(self) -> list[ReadableSpan]: + """Find exported spans that have gen_ai.input.messages.""" + get_tracer_provider().force_flush() + time.sleep(0.5) + return [ + s + for s in self.exporter.spans + if s.attributes and GEN_AI_INPUT_MESSAGES_KEY in s.attributes + ] + + @pytest.mark.asyncio + async def test_simple_chat_message_mapping( + self, + openai_client: AsyncAzureOpenAI, + azure_openai_config: dict[str, Any], + ) -> None: + """Simple chat: verify exported spans contain versioned A365 messages.""" + agent = Agent( + name="TestAgent", + instructions="You are a helpful assistant. Reply in one sentence.", + model=OpenAIChatCompletionsModel( + model=azure_openai_config["deployment"], + openai_client=openai_client, + ), + ) + + result = await Runner.run(agent, "What is the capital of France?") + assert result is not None + assert len(result.final_output) > 0 + + # Print ALL spans as full JSON + print(f"\n=== All exported spans ({len(self.exporter.spans)}) ===") + for i, s in enumerate(self.exporter.spans): + span_json = _span_to_json(s) + print(f"\n--- SPAN {i} ---") + print(json.dumps(span_json, indent=2, default=str)) + + message_spans = self._find_message_spans() + assert len(message_spans) > 0, ( + f"No message spans found. All spans: {[s.name for s in self.exporter.spans]}" + ) + + # Verify at least one span has versioned A365 format + found_versioned = False + for span in message_spans: + attrs = dict(span.attributes or {}) + raw_input = attrs.get(GEN_AI_INPUT_MESSAGES_KEY) + if raw_input: + input_data = json.loads(raw_input) + if isinstance(input_data, dict) and input_data.get("version") == "0.1.0": + found_versioned = True + messages = input_data["messages"] + roles = [m["role"] for m in messages] + assert "user" in roles + for msg in messages: + for part in msg["parts"]: + assert "type" in part + + raw_output = attrs.get(GEN_AI_OUTPUT_MESSAGES_KEY) + if raw_output: + output_data = json.loads(raw_output) + if isinstance(output_data, dict) and output_data.get("version") == "0.1.0": + out_messages = output_data["messages"] + assert out_messages[0]["role"] == "assistant" + assert any(p["type"] == "text" for p in out_messages[0]["parts"]) + + assert found_versioned, "Expected at least one span with versioned A365 message format" + + @pytest.mark.asyncio + async def test_tool_call_message_mapping( + self, + openai_client: AsyncAzureOpenAI, + azure_openai_config: dict[str, Any], + ) -> None: + """Tool-calling chat: verify tool_call and tool_call_response parts.""" + agent = Agent( + name="WeatherAgent", + instructions="You are a weather assistant. Always use the get_weather function.", + model=OpenAIChatCompletionsModel( + model=azure_openai_config["deployment"], + openai_client=openai_client, + ), + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Seattle?") + assert result is not None + assert len(result.final_output) > 0 + + # Print ALL spans as full JSON + print(f"\n=== All exported spans ({len(self.exporter.spans)}) ===") + for i, s in enumerate(self.exporter.spans): + span_json = _span_to_json(s) + print(f"\n--- SPAN {i} ---") + print(json.dumps(span_json, indent=2, default=str)) + + message_spans = self._find_message_spans() + assert len(message_spans) > 0 + + # Collect part types from exported (enriched) spans + part_types: set[str] = set() + for span in message_spans: + attrs = dict(span.attributes or {}) + for key in (GEN_AI_INPUT_MESSAGES_KEY, GEN_AI_OUTPUT_MESSAGES_KEY): + raw = attrs.get(key) + if not raw: + continue + data = json.loads(raw) + if isinstance(data, dict) and "messages" in data: + messages = data["messages"] + for msg in messages: + for part in msg.get("parts", []): + part_types.add(part.get("type", "")) + + print(f"\n Exported part types: {part_types}") + assert "text" in part_types, f"Expected text in exported parts: {part_types}" diff --git a/tests/observability/extensions/openai/test_message_mapper.py b/tests/observability/extensions/openai/test_message_mapper.py new file mode 100644 index 00000000..88ca83a8 --- /dev/null +++ b/tests/observability/extensions/openai/test_message_mapper.py @@ -0,0 +1,261 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for the OpenAI message mapper.""" + +import json + +from microsoft_agents_a365.observability.extensions.openai.message_mapper import ( + map_input_messages, + map_output_messages, +) + + +class TestMapInputMessages: + """Tests for map_input_messages.""" + + def test_empty_string_returns_none(self) -> None: + assert map_input_messages("") is None + + def test_whitespace_only_returns_none(self) -> None: + assert map_input_messages(" ") is None + + def test_plain_string_wraps_as_user_message(self) -> None: + result = map_input_messages("Hello world") + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert len(data["messages"]) == 1 + assert data["messages"][0]["role"] == "user" + assert data["messages"][0]["parts"][0]["type"] == "text" + assert data["messages"][0]["parts"][0]["content"] == "Hello world" + + def test_chat_completions_format(self) -> None: + """Standard chat completions format with system + user messages.""" + raw = json.dumps([ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi there"}, + ]) + result = map_input_messages(raw) + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert len(data["messages"]) == 2 + assert data["messages"][0]["role"] == "system" + assert data["messages"][0]["parts"][0]["content"] == "You are helpful." + assert data["messages"][1]["role"] == "user" + assert data["messages"][1]["parts"][0]["content"] == "Hi there" + + def test_chat_completions_with_tool_calls(self) -> None: + """Messages with assistant tool_calls and tool response.""" + raw = json.dumps([ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "function": {"name": "add", "arguments": '{"a":2,"b":2}'}, + } + ], + }, + {"role": "tool", "content": "4", "tool_call_id": "call_123"}, + ]) + result = map_input_messages(raw) + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert len(data["messages"]) == 3 + + # User message + assert data["messages"][0]["role"] == "user" + assert data["messages"][0]["parts"][0]["type"] == "text" + + # Assistant with tool call + assert data["messages"][1]["role"] == "assistant" + assert data["messages"][1]["parts"][0]["type"] == "tool_call" + assert data["messages"][1]["parts"][0]["name"] == "add" + assert data["messages"][1]["parts"][0]["id"] == "call_123" + + # Tool response + assert data["messages"][2]["role"] == "tool" + assert data["messages"][2]["parts"][0]["type"] == "tool_call_response" + assert data["messages"][2]["parts"][0]["id"] == "call_123" + assert data["messages"][2]["parts"][0]["response"] == "4" + + def test_response_input_item_param_format(self) -> None: + """ResponseInputItemParam format with typed items.""" + raw = json.dumps([ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + }, + { + "type": "function_call", + "name": "get_weather", + "call_id": "fc_1", + "arguments": '{"city":"Seattle"}', + }, + {"type": "function_call_output", "call_id": "fc_1", "output": "Sunny, 22C"}, + ]) + result = map_input_messages(raw) + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert len(data["messages"]) == 3 + + # Message + assert data["messages"][0]["role"] == "user" + assert data["messages"][0]["parts"][0]["type"] == "text" + + # Function call + assert data["messages"][1]["role"] == "assistant" + assert data["messages"][1]["parts"][0]["type"] == "tool_call" + assert data["messages"][1]["parts"][0]["name"] == "get_weather" + + # Function call output + assert data["messages"][2]["role"] == "tool" + assert data["messages"][2]["parts"][0]["type"] == "tool_call_response" + assert data["messages"][2]["parts"][0]["response"] == "Sunny, 22C" + + def test_message_without_type_field(self) -> None: + """Messages without explicit 'type' field (EasyInputMessageParam).""" + raw = json.dumps([ + {"role": "user", "content": "Hello"}, + ]) + result = map_input_messages(raw) + assert result is not None + data = json.loads(result) + assert data["messages"][0]["role"] == "user" + + def test_invalid_json_wraps_as_plain_text(self) -> None: + result = map_input_messages("not json {") + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert data["messages"][0]["parts"][0]["content"] == "not json {" + + def test_empty_list_returns_none(self) -> None: + assert map_input_messages("[]") is None + + +class TestMapOutputMessages: + """Tests for map_output_messages.""" + + def test_empty_string_returns_none(self) -> None: + assert map_output_messages("") is None + + def test_plain_string_wraps_as_assistant(self) -> None: + result = map_output_messages("The answer is 42.") + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert data["messages"][0]["role"] == "assistant" + assert data["messages"][0]["parts"][0]["content"] == "The answer is 42." + + def test_chat_completions_output(self) -> None: + """Standard chat completions output with finish_reason.""" + raw = json.dumps([ + { + "role": "assistant", + "content": "Paris is the capital.", + "finish_reason": "stop", + } + ]) + result = map_output_messages(raw) + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert len(data["messages"]) == 1 + msg = data["messages"][0] + assert msg["role"] == "assistant" + assert msg["parts"][0]["type"] == "text" + assert msg["parts"][0]["content"] == "Paris is the capital." + assert msg["finish_reason"] == "stop" + + def test_chat_completions_with_tool_calls(self) -> None: + """Output with tool_calls.""" + raw = json.dumps([ + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "function": {"name": "search", "arguments": '{"q":"test"}'}, + } + ], + "finish_reason": "tool_calls", + } + ]) + result = map_output_messages(raw) + assert result is not None + data = json.loads(result) + msg = data["messages"][0] + assert msg["role"] == "assistant" + assert msg["parts"][0]["type"] == "tool_call" + assert msg["parts"][0]["name"] == "search" + assert msg["finish_reason"] == "tool_calls" + + def test_response_json_format(self) -> None: + """Full OpenAI Response JSON (from model_dump_json).""" + raw = json.dumps({ + "id": "resp_123", + "model": "gpt-4o", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello!"}], + "status": "completed", + } + ], + }) + result = map_output_messages(raw) + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + msg = data["messages"][0] + assert msg["role"] == "assistant" + assert msg["parts"][0]["type"] == "text" + assert msg["parts"][0]["content"] == "Hello!" + + def test_response_json_with_function_call(self) -> None: + """Response JSON with function_call output item.""" + raw = json.dumps({ + "id": "resp_456", + "model": "gpt-4o", + "output": [ + { + "type": "function_call", + "name": "get_weather", + "call_id": "fc_1", + "arguments": '{"city":"NYC"}', + } + ], + }) + result = map_output_messages(raw) + assert result is not None + data = json.loads(result) + msg = data["messages"][0] + assert msg["role"] == "assistant" + assert msg["parts"][0]["type"] == "tool_call" + assert msg["parts"][0]["name"] == "get_weather" + assert msg["finish_reason"] == "tool_call" + + def test_response_json_without_output_returns_none(self) -> None: + """Response JSON without output field.""" + raw = json.dumps({"id": "resp_789", "model": "gpt-4o"}) + assert map_output_messages(raw) is None + + def test_empty_list_returns_none(self) -> None: + assert map_output_messages("[]") is None + + def test_invalid_json_wraps_as_plain_text(self) -> None: + result = map_output_messages("bad json") + assert result is not None + data = json.loads(result) + assert data["version"] == "0.1.0" + assert data["messages"][0]["role"] == "assistant"