From 213fe8debb76f7a321813473dbb5f6a2ead324f6 Mon Sep 17 00:00:00 2001 From: Gavin Aguiar Date: Wed, 21 Jan 2026 12:05:20 -0600 Subject: [PATCH 1/2] MCP tool fix for azurefunctions --- .../agent_framework_azurefunctions/_app.py | 14 ++++---- .../packages/azurefunctions/tests/test_app.py | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 29e4a7df6a..696548c008 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -608,14 +608,12 @@ async def _handle_mcp_tool_invocation( # Create or parse session ID if thread_id and isinstance(thread_id, str) and thread_id.strip(): - try: - session_id = AgentSessionId.parse(thread_id) - except ValueError as e: - logger.warning( - "Failed to parse AgentSessionId from thread_id '%s': %s. Falling back to new session ID.", - thread_id, - e, - ) + # If thread_id is in @name@key format, extract only the key portion + if thread_id.startswith("@") and "@" in thread_id[1:]: + key = thread_id[1:].split("@", 1)[1] + session_id = AgentSessionId(name=agent_name, key=key) + else: + # Use thread_id as-is for the key session_id = AgentSessionId(name=agent_name, key=thread_id) else: # Generate new session ID diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 29d614e729..c56c045022 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1056,6 +1056,40 @@ async def test_handle_mcp_tool_invocation_runtime_error(self) -> None: with pytest.raises(RuntimeError, match="Agent execution failed"): await app._handle_mcp_tool_invocation("TestAgent", context, client) + async def test_handle_mcp_tool_invocation_ignores_agent_name_in_thread_id(self) -> None: + """Test that MCP tool invocation uses the agent_name parameter, not the name from thread_id.""" + mock_agent = Mock() + mock_agent.name = "PlantAdvisor" + + app = AgentFunctionApp(agents=[mock_agent]) + client = AsyncMock() + + # Mock the entity response + mock_state = Mock() + mock_state.entity_state = { + "schemaVersion": "1.0.0", + "data": {"conversationHistory": []}, + } + client.read_entity_state.return_value = mock_state + + # Thread ID contains a different agent name (@StockAdvisor@poc123) + # but we're invoking PlantAdvisor - it should use PlantAdvisor's entity + context = json.dumps({"arguments": {"query": "test query", "threadId": "@StockAdvisor@test123"}}) + + with patch.object(app, "_get_response_from_entity") as get_response_mock: + get_response_mock.return_value = {"status": "success", "response": "Test response"} + + await app._handle_mcp_tool_invocation("PlantAdvisor", context, client) + + # Verify signal_entity was called with PlantAdvisor's entity, not StockAdvisor's + client.signal_entity.assert_called_once() + call_args = client.signal_entity.call_args + entity_id = call_args[0][0] + + # Entity name should be dafx-PlantAdvisor, not dafx-StockAdvisor + assert entity_id.name == "dafx-PlantAdvisor" + assert entity_id.key == "test123" + def test_health_check_includes_mcp_tool_enabled(self) -> None: """Test that health check endpoint includes mcp_tool_enabled field.""" mock_agent = Mock() From 9c61a7fb10dceac0c17c3f9f4f431566a1f54763 Mon Sep 17 00:00:00 2001 From: Gavin Aguiar Date: Thu, 22 Jan 2026 15:53:39 -0600 Subject: [PATCH 2/2] Moving logic to check for thread id --- .../agent_framework_azurefunctions/_app.py | 14 +++++---- .../agent_framework_azurefunctions/_models.py | 26 ++++++++++------ .../packages/azurefunctions/tests/test_app.py | 30 +++++++++++++++++++ .../azurefunctions/tests/test_models.py | 28 +++++++++++++++++ 4 files changed, 83 insertions(+), 15 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 696548c008..5e64a3feaf 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -608,12 +608,14 @@ async def _handle_mcp_tool_invocation( # Create or parse session ID if thread_id and isinstance(thread_id, str) and thread_id.strip(): - # If thread_id is in @name@key format, extract only the key portion - if thread_id.startswith("@") and "@" in thread_id[1:]: - key = thread_id[1:].split("@", 1)[1] - session_id = AgentSessionId(name=agent_name, key=key) - else: - # Use thread_id as-is for the key + try: + session_id = AgentSessionId.parse(thread_id, agent_name=agent_name) + except ValueError as e: + logger.warning( + "Failed to parse AgentSessionId from thread_id '%s': %s. Falling back to new session ID.", + thread_id, + e, + ) session_id = AgentSessionId(name=agent_name, key=thread_id) else: # Generate new session ID diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index 2ab9667575..ffee3b77fe 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -109,26 +109,34 @@ def __repr__(self) -> str: return f"AgentSessionId(name='{self.name}', key='{self.key}')" @staticmethod - def parse(session_id_string: str) -> AgentSessionId: + def parse(session_id_string: str, agent_name: str | None = None) -> AgentSessionId: """Parses a string representation of an agent session ID. Args: - session_id_string: A string in the form @name@key + session_id_string: A string in the form @name@key, or a plain key string + when agent_name is provided. + agent_name: Optional agent name to use instead of parsing from the string. + If provided, only the key portion is extracted from session_id_string + (for @name@key format) or the entire string is used as the key + (for plain strings). Returns: AgentSessionId instance Raises: - ValueError: If the string format is invalid + ValueError: If the string format is invalid and agent_name is not provided """ - if not session_id_string.startswith("@"): - raise ValueError(f"Invalid agent session ID format: {session_id_string}") + # Check if string is in @name@key format + if session_id_string.startswith("@") and "@" in session_id_string[1:]: + parts = session_id_string[1:].split("@", 1) + name = agent_name if agent_name is not None else parts[0] + return AgentSessionId(name=name, key=parts[1]) - parts = session_id_string[1:].split("@", 1) - if len(parts) != 2: - raise ValueError(f"Invalid agent session ID format: {session_id_string}") + # Plain string format - only valid when agent_name is provided + if agent_name is not None: + return AgentSessionId(name=agent_name, key=session_id_string) - return AgentSessionId(name=parts[0], key=parts[1]) + raise ValueError(f"Invalid agent session ID format: {session_id_string}") class DurableAgentThread(AgentThread): diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index c56c045022..b4b0428f43 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1090,6 +1090,36 @@ async def test_handle_mcp_tool_invocation_ignores_agent_name_in_thread_id(self) assert entity_id.name == "dafx-PlantAdvisor" assert entity_id.key == "test123" + async def test_handle_mcp_tool_invocation_uses_plain_thread_id_as_key(self) -> None: + """Test that a plain thread_id (not in @name@key format) is used as-is for the key.""" + mock_agent = Mock() + mock_agent.name = "TestAgent" + + app = AgentFunctionApp(agents=[mock_agent]) + client = AsyncMock() + + mock_state = Mock() + mock_state.entity_state = { + "schemaVersion": "1.0.0", + "data": {"conversationHistory": []}, + } + client.read_entity_state.return_value = mock_state + + # Plain thread_id without @name@key format + context = json.dumps({"arguments": {"query": "test query", "threadId": "simple-thread-123"}}) + + with patch.object(app, "_get_response_from_entity") as get_response_mock: + get_response_mock.return_value = {"status": "success", "response": "Test response"} + + await app._handle_mcp_tool_invocation("TestAgent", context, client) + + client.signal_entity.assert_called_once() + call_args = client.signal_entity.call_args + entity_id = call_args[0][0] + + assert entity_id.name == "dafx-TestAgent" + assert entity_id.key == "simple-thread-123" + def test_health_check_includes_mcp_tool_enabled(self) -> None: """Test that health check endpoint includes mcp_tool_enabled field.""" mock_agent = Mock() diff --git a/python/packages/azurefunctions/tests/test_models.py b/python/packages/azurefunctions/tests/test_models.py index 74efa9c166..be31f59800 100644 --- a/python/packages/azurefunctions/tests/test_models.py +++ b/python/packages/azurefunctions/tests/test_models.py @@ -120,6 +120,34 @@ def test_parse_round_trip(self) -> None: assert parsed.name == original.name assert parsed.key == original.key + def test_parse_with_agent_name_override(self) -> None: + """Test parsing @name@key format with agent_name parameter overrides the name.""" + session_id = AgentSessionId.parse("@OriginalAgent@test-key-123", agent_name="OverriddenAgent") + + assert session_id.name == "OverriddenAgent" + assert session_id.key == "test-key-123" + + def test_parse_without_agent_name_uses_parsed_name(self) -> None: + """Test parsing @name@key format without agent_name uses name from string.""" + session_id = AgentSessionId.parse("@ParsedAgent@test-key-123") + + assert session_id.name == "ParsedAgent" + assert session_id.key == "test-key-123" + + def test_parse_plain_string_with_agent_name(self) -> None: + """Test parsing plain string with agent_name uses entire string as key.""" + session_id = AgentSessionId.parse("simple-thread-123", agent_name="TestAgent") + + assert session_id.name == "TestAgent" + assert session_id.key == "simple-thread-123" + + def test_parse_plain_string_without_agent_name_raises(self) -> None: + """Test parsing plain string without agent_name raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + AgentSessionId.parse("simple-thread-123") + + assert "Invalid agent session ID format" in str(exc_info.value) + def test_to_entity_name_adds_prefix(self) -> None: """Test that to_entity_name adds the dafx- prefix.""" entity_name = AgentSessionId.to_entity_name("TestAgent")