Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions adk/agenticlayer/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@

from agenticlayer.config import InteractionType, McpTool, SubAgent
from agenticlayer.constants import HTTP_HEADERS_SESSION_KEY
from agenticlayer.mcp_session_patch import apply_mcp_session_patch

logger = logging.getLogger(__name__)

# Apply the MCP session manager patch on module import
# This fixes the session invalidation issue when MCP servers restart
apply_mcp_session_patch()


def _create_header_provider(propagate_headers: list[str]) -> Callable[[ReadonlyContext], dict[str, str]]:
"""Create a header provider function for a specific MCP server.
Expand Down
180 changes: 180 additions & 0 deletions adk/agenticlayer/mcp_session_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
Monkey-patch for Google ADK's MCPSessionManager to fix session invalidation on server restart.

This module patches both the MCPSessionManager.create_session method and the retry_on_errors
decorator to properly handle the case where an MCP server restarts and loses session state.

Root Cause:
-----------
When an MCP server restarts:
1. The server loses all session state
2. Client-side session streams remain open (not disconnected)
3. Cached session appears valid because _is_session_disconnected() only checks local streams
4. Server returns 404 for requests with old session IDs
5. Tool calls time out waiting for responses
6. On retry, the same bad cached session is reused

The Fix:
--------
We patch the retry_on_errors decorator to:
1. Detect when an error occurs during MCP operations
2. Force-close the streams of the cached session
3. This makes _is_session_disconnected() return True
4. On retry, create_session() sees the session is disconnected and creates a fresh one

This is a temporary workaround until the fix is merged upstream in Google ADK.

Issue: https://github.com/agentic-layer/sdk-python/issues/XXX  (TODO: replace XXX with the real issue number before merge)
"""

import functools
import logging
from typing import Any, Callable

from google.adk.tools.mcp_tool import mcp_toolset
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager

logger = logging.getLogger(__name__)

# Original retry_on_errors decorator; captured by apply_mcp_session_patch().
_original_retry_on_errors = None


async def _close_session_streams(session: Any) -> None:
    """Force-close one cached session's read/write streams.

    Closing the streams is what makes ADK's ``_is_session_disconnected()``
    report the session as disconnected, so the next ``create_session()``
    builds a fresh session instead of reusing the stale one.

    Closing is best-effort: a stream may expose an async ``aclose()`` or a
    sync ``close()``; a stream with neither is left alone (logged only).
    """
    for attr_name, label in (("_read_stream", "read"), ("_write_stream", "write")):
        stream = getattr(session, attr_name, None)
        if stream is None:
            logger.debug("[PATCH] Session has no %s", attr_name)
            continue
        logger.debug(
            "[PATCH] %s stream type: %s, has aclose: %s",
            label,
            type(stream).__name__,
            hasattr(stream, "aclose"),
        )
        if hasattr(stream, "aclose"):
            await stream.aclose()
            logger.debug("[PATCH] Closed %s stream via aclose()", label)
        elif hasattr(stream, "close"):
            stream.close()
            logger.debug("[PATCH] Closed %s stream via close()", label)


async def _invalidate_cached_sessions(session_manager: Any, error: Exception, func_name: str) -> None:
    """Close the streams of every session cached on *session_manager*.

    Failures while closing one session are logged and do not stop the
    invalidation of the remaining sessions.
    """
    sessions = getattr(session_manager, "_sessions", None)
    if sessions is None:
        logger.debug("[PATCH] Session manager has no _sessions attribute")
        return

    logger.info(
        "[PATCH] Error in MCP operation (%s), invalidating %d cached sessions: %s",
        func_name,
        len(sessions),
        type(error).__name__,
    )
    # NOTE(review): we deliberately do not take the manager's session lock —
    # we may already be inside an async context that holds it. Closing the
    # streams without the lock is assumed safe for this best-effort patch;
    # confirm against the installed ADK's MCPSessionManager implementation.
    # Cache values are assumed to be 3-tuples whose first item is the session.
    for session_key, (session, _, _) in list(sessions.items()):
        try:
            logger.debug("[PATCH] Invalidating session: %s", session_key)
            await _close_session_streams(session)
            logger.info("[PATCH] Successfully invalidated session %s", session_key)
        except Exception as close_err:
            logger.warning("[PATCH] Could not close streams for %s: %s", session_key, close_err)


def _patched_retry_on_errors(func: Callable[..., Any]) -> Callable[..., Any]:
    """Patched version of retry_on_errors that invalidates sessions on error.

    Wraps *func* with the original decorator (when the patch is applied) and
    adds logic that, on any exception, force-closes the cached MCP sessions'
    streams so that the retry path creates a fresh session.
    """
    # Apply the original decorator first so its retry loop sees our wrapper's
    # exception and drives the retry.
    if _original_retry_on_errors:
        func = _original_retry_on_errors(func)

    @functools.wraps(func)
    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        try:
            return await func(self, *args, **kwargs)
        except Exception as e:
            session_manager = getattr(self, "_mcp_session_manager", None)
            if session_manager is None:
                logger.debug(
                    "[PATCH] Object %s has no _mcp_session_manager attribute",
                    type(self).__name__,
                )
            else:
                try:
                    await _invalidate_cached_sessions(session_manager, e, func.__name__)
                except Exception as invalidate_err:
                    # Invalidation is best-effort; never mask the real error.
                    logger.error("[PATCH] Error invalidating sessions: %s", invalidate_err, exc_info=True)
            # Re-raise so the original decorator (if any) can handle retry.
            raise

    return wrapper


# McpToolset methods re-decorated by apply_mcp_session_patch(), keyed by
# attribute name, so remove_mcp_session_patch() can restore them.
_patched_toolset_methods: dict[str, Any] = {}


def apply_mcp_session_patch() -> None:
    """Apply the monkey-patch to the retry_on_errors decorator.

    This should be called once during application initialization before
    any MCP tools are created. Calling it again is a no-op (a warning is
    logged).
    """
    global _original_retry_on_errors

    if _original_retry_on_errors is not None:
        logger.warning("MCP session manager patch already applied")
        return

    logger.info("Applying MCP session manager patch for server restart handling")

    # Imported here so merely importing this module does not require the
    # toolset module to be loadable.
    from google.adk.tools.mcp_tool import mcp_session_manager
    from google.adk.tools.mcp_tool.mcp_toolset import McpToolset

    # Store the original decorator, then replace it in its defining module so
    # future decorations pick up the patched version.
    _original_retry_on_errors = mcp_session_manager.retry_on_errors
    mcp_session_manager.retry_on_errors = _patched_retry_on_errors

    # Methods already decorated at class-definition time are not affected by
    # the assignment above, so re-decorate them in place.
    # NOTE(review): this matches any public callable exposing __wrapped__,
    # which could over-match methods wrapped by other functools.wraps-based
    # decorators — confirm against the installed ADK version.
    for attr_name in dir(McpToolset):
        if attr_name.startswith("_"):
            continue
        attr = getattr(McpToolset, attr_name)
        if callable(attr) and hasattr(attr, "__wrapped__"):
            # Remember the currently-installed method so the patch is reversible.
            _patched_toolset_methods[attr_name] = attr
            setattr(McpToolset, attr_name, _patched_retry_on_errors(attr.__wrapped__))
            logger.debug(f"Re-decorated McpToolset.{attr_name}")

    logger.info("MCP session manager patch applied successfully")


def remove_mcp_session_patch() -> None:
    """Remove the monkey-patch and restore original behavior.

    This is primarily for testing purposes. Restores both the module-level
    decorator and the McpToolset methods that apply_mcp_session_patch()
    re-decorated (previously the re-decorated methods were left patched,
    so the patch silently survived "removal").
    """
    global _original_retry_on_errors

    if _original_retry_on_errors is None:
        return

    logger.info("Removing MCP session manager patch")

    from google.adk.tools.mcp_tool import mcp_session_manager
    from google.adk.tools.mcp_tool.mcp_toolset import McpToolset

    mcp_session_manager.retry_on_errors = _original_retry_on_errors

    if hasattr(mcp_toolset, "retry_on_errors"):
        mcp_toolset.retry_on_errors = _original_retry_on_errors

    # Restore the methods we re-decorated during apply.
    for attr_name, original_method in _patched_toolset_methods.items():
        setattr(McpToolset, attr_name, original_method)
    _patched_toolset_methods.clear()

    _original_retry_on_errors = None
    logger.info("MCP session manager patch removed")
157 changes: 157 additions & 0 deletions adk/tests/test_agent_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,160 @@ def echo(message: str) -> str:
assert "x-external-token" not in headers_lower, "X-External-Token should not be sent"
assert "x-api-key" not in headers_lower, "X-API-Key should not be sent"
assert "authorization" not in headers_lower, "Authorization should not be sent"

@pytest.mark.asyncio
async def test_mcp_server_restart_causes_communication_failure(
    self,
    app_factory: Any,
    agent_factory: Any,
    llm_controller: LLMMockController,
    respx_mock: respx.MockRouter,
) -> None:
    """Verify the agent recovers when the MCP server restarts mid-conversation.

    (Name kept for history: this test originally *reproduced* the failure;
    it now acts as a regression test for the fix in mcp_session_patch.)

    Scenario:
    1. Agent successfully calls an MCP server tool.
    2. MCP server "restarts": the mock starts rejecting previously seen
       session IDs with 404, simulating lost server-side session state.
    3. The SAME agent instance calls the tool again.

    With apply_mcp_session_patch() active (applied on agent-module import),
    the agent must detect the invalid cached session, establish a new one,
    and complete the second tool call — so the second task must reach
    state "completed", not "failed".
    """

    # Given: Mock LLM to call 'add' tool twice
    llm_controller.respond_with_tool_call(
        pattern="first call",
        tool_name="add",
        tool_args={"a": 2, "b": 3},
        final_message="First calculation done: 5",
    )
    llm_controller.respond_with_tool_call(
        pattern="second call",
        tool_name="add",
        tool_args={"a": 10, "b": 20},
        final_message="Second calculation done: 30",
    )

    # Given: MCP server with 'add' tool
    mcp = FastMCP("Calculator")

    @mcp.tool()
    def add(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    mcp_server_url = "http://test-mcp-restart.local"
    mcp_app = mcp.http_app(path="/mcp")

    # Mutable state shared with the request handler to simulate a restart.
    server_state: dict[str, Any] = {
        "accept_old_sessions": True,
        "old_session_ids": set(),
    }

    async with LifespanManager(mcp_app) as mcp_manager:

        async def session_handler(request: httpx.Request) -> httpx.Response:
            """Proxy to the MCP app; reject known session IDs after 'restart'."""
            url_str = str(request.url)

            # Extract session ID from URL path (e.g. /mcp/messages/SESSION_ID)
            session_id = None
            if "/messages/" in url_str:
                tail = url_str.split("/messages/", 1)[1]
                session_id = tail.split("/")[0].split("?")[0]

            if (
                session_id
                and session_id in server_state["old_session_ids"]
                and not server_state["accept_old_sessions"]
            ):
                # Server has "restarted" and doesn't recognize old sessions.
                return httpx.Response(
                    status_code=404,
                    json={"error": "Session not found"},
                    headers={"content-type": "application/json"},
                )

            # Forward request to the real MCP ASGI app.
            transport = httpx.ASGITransport(app=mcp_manager.app)
            async with httpx.AsyncClient(transport=transport, base_url=mcp_server_url) as client:
                response = await client.request(
                    method=request.method,
                    url=str(request.url),
                    headers=request.headers,
                    content=request.content,
                )

            # Remember session IDs that the server accepted, so we can
            # invalidate exactly those after the simulated restart.
            if session_id and response.status_code == 200:
                server_state["old_session_ids"].add(session_id)

            return response

        respx_mock.route(host="test-mcp-restart.local").mock(side_effect=session_handler)

        # When: Create agent with MCP tool
        test_agent = agent_factory("test_agent")
        tools = [McpTool(name="calc", url=AnyHttpUrl(f"{mcp_server_url}/mcp"), timeout=30)]

        async with app_factory(test_agent, tools=tools) as app:
            client = TestClient(app)

            # ===== FIRST CALL: must succeed =====
            response1 = client.post("", json=create_send_message_request("first call: Calculate 2 + 3"))
            assert response1.status_code == 200
            result1 = verify_jsonrpc_response(response1.json())
            assert result1["status"]["state"] == "completed", "First task should complete successfully"

            # ===== SIMULATE SERVER RESTART =====
            # From now on the server rejects all previously established sessions.
            server_state["accept_old_sessions"] = False

            # ===== SECOND CALL: must succeed via session re-establishment =====
            response2 = client.post("", json=create_send_message_request("second call: Calculate 10 + 20"))
            assert response2.status_code == 200, "A2A response should be 200 even if task failed"
            result2 = verify_jsonrpc_response(response2.json())

            task_state = result2["status"]["state"]
            if task_state == "failed":
                error_msg = result2["status"].get("message", {}).get("parts", [{}])[0].get("text", "")
                pytest.fail(
                    "Agent failed to communicate with the MCP server after restart; "
                    f"the session patch did not recover the session. Error: {error_msg}"
                )
            assert task_state == "completed", f"Unexpected task state: {task_state}"