From 61c6802b72aa71c59a0a8681a7e179ad8506b341 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 25 Mar 2026 17:48:22 +0100 Subject: [PATCH] fix: refresh tool cache after remote MCP server reconnect After a successful reconnect in watchConnection, eagerly call refreshToolCache and refreshPromptCache so the runtime picks up the new server's tools and prompts and the toolsChangedHandler is invoked. Without this, the runtime kept using its stale copy of the tool list. Add TestRemoteReconnectRefreshesTools to verify that when a remote MCP server restarts with a different set of tools the Toolset returns the updated tools and notifies the handler. Extract startMCPServer and allocateAddr test helpers to reduce duplication across reconnect tests. Fixes #2244 Assisted-By: docker-agent --- pkg/tools/mcp/mcp.go | 7 ++ pkg/tools/mcp/reconnect_test.go | 125 ++++++++++++++++++++++++++++++-- 2 files changed, 127 insertions(+), 5 deletions(-) diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index 48cf1f956..46ea4b77e 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -292,6 +292,13 @@ func (ts *Toolset) watchConnection(ctx context.Context) { if !ts.tryRestart(ctx) { return } + + // After a successful restart, eagerly refresh the tool and prompt + // caches and notify the runtime so it picks up the new server's + // state. The new server may expose a different set of tools/prompts, + // and without this the runtime would keep using its stale copy. + ts.refreshToolCache(ctx) + ts.refreshPromptCache(ctx) } } diff --git a/pkg/tools/mcp/reconnect_test.go b/pkg/tools/mcp/reconnect_test.go index a95ff4a9c..eeacc52a6 100644 --- a/pkg/tools/mcp/reconnect_test.go +++ b/pkg/tools/mcp/reconnect_test.go @@ -17,6 +17,46 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) +// startMCPServer creates a minimal MCP server on addr with the given tools +// and returns a function to shut it down. 
+func startMCPServer(t *testing.T, addr string, mcpTools ...*gomcp.Tool) (shutdown func()) {
+	t.Helper()
+
+	s := gomcp.NewServer(&gomcp.Implementation{Name: "test-server", Version: "1.0.0"}, nil)
+	for _, tool := range mcpTools {
+		s.AddTool(tool, func(_ context.Context, _ *gomcp.CallToolRequest) (*gomcp.CallToolResult, error) {
+			return &gomcp.CallToolResult{
+				Content: []gomcp.Content{&gomcp.TextContent{Text: "ok-" + tool.Name}},
+			}, nil
+		})
+	}
+
+	// Retry Listen until the port is available (e.g. after a server shutdown).
+	var srvLn net.Listener
+	require.Eventually(t, func() bool {
+		var listenErr error
+		srvLn, listenErr = net.Listen("tcp", addr)
+		return listenErr == nil
+	}, 2*time.Second, 50*time.Millisecond, "port %s not available in time", addr)
+
+	srv := &http.Server{
+		Handler: gomcp.NewStreamableHTTPHandler(func(*http.Request) *gomcp.Server { return s }, nil),
+	}
+	go func() { _ = srv.Serve(srvLn) }()
+	t.Cleanup(func() { _ = srv.Close() }) // safety net: free the port even if the test aborts before calling shutdown
+	return func() { _ = srv.Close() }
+}
+
+// allocateAddr returns a free TCP address on localhost.
+func allocateAddr(t *testing.T) string {
+	t.Helper()
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	addr := ln.Addr().String()
+	require.NoError(t, ln.Close()) // only the address was needed; fail loudly if the port cannot be released
+	return addr
+}
+
 // TestRemoteReconnectAfterServerRestart verifies that a Toolset backed by a
 // real remote (streamable-HTTP) MCP server transparently recovers when the
 // server is restarted.
@@ -33,11 +73,7 @@ import (
 func TestRemoteReconnectAfterServerRestart(t *testing.T) {
 	t.Parallel()
 
-	// Use a fixed listener address so we can restart on the same port.
-	ln, err := net.Listen("tcp", "127.0.0.1:0")
-	require.NoError(t, err)
-	addr := ln.Addr().String()
-	ln.Close() // We only needed the address; close so startServer can bind it.
+	addr := allocateAddr(t)
 
 	var callCount atomic.Int32
 
@@ -123,3 +159,82 @@ func TestRemoteReconnectAfterServerRestart(t *testing.T) {
 		t.Fatal("reconnect did not complete: restarted channel was not closed")
 	}
 }
+
+// TestRemoteReconnectRefreshesTools verifies that after a remote MCP server
+// restarts with a different set of tools, the Toolset picks up the new tools
+// and notifies the runtime via the toolsChangedHandler.
+//
+// This is the scenario from https://github.com/docker/docker-agent/issues/2244:
+//   - Server v1 exposes tools [alpha, shared].
+//   - Client connects and caches [alpha, shared].
+//   - Server v1 shuts down; server v2 starts with tools [beta, shared].
+//   - A tool call to "shared" triggers reconnection.
+//   - After reconnection, Tools() must return [beta, shared], not the stale [alpha, shared].
+//   - The toolsChangedHandler must be called so the runtime refreshes its own state.
+func TestRemoteReconnectRefreshesTools(t *testing.T) {
+	t.Parallel()
+
+	addr := allocateAddr(t)
+
+	// "shared" exists on both servers so we can call it to trigger reconnect.
+	sharedTool := &gomcp.Tool{Name: "shared", InputSchema: &jsonschema.Schema{Type: "object"}}
+	alphaTool := &gomcp.Tool{Name: "alpha", InputSchema: &jsonschema.Schema{Type: "object"}}
+	betaTool := &gomcp.Tool{Name: "beta", InputSchema: &jsonschema.Schema{Type: "object"}}
+
+	// --- Start server v1 with tools "alpha" + "shared" ---
+	shutdown1 := startMCPServer(t, addr, alphaTool, sharedTool)
+
+	ts := NewRemoteToolset("ns", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil)
+
+	// Track toolsChangedHandler invocations.
+	toolsChangedCh := make(chan struct{}, 1)
+	ts.SetToolsChangedHandler(func() {
+		select {
+		case toolsChangedCh <- struct{}{}:
+		default:
+		}
+	})
+
+	require.NoError(t, ts.Start(t.Context()))
+
+	// Verify initial tools.
+	toolList, err := ts.Tools(t.Context())
+	require.NoError(t, err)
+	require.Len(t, toolList, 2)
+	toolNames := []string{toolList[0].Name, toolList[1].Name}
+	assert.Contains(t, toolNames, "ns_alpha")
+	assert.Contains(t, toolNames, "ns_shared")
+
+	// --- Shut down server v1, start server v2 with tools "beta" + "shared" ---
+	shutdown1()
+
+	shutdown2 := startMCPServer(t, addr, betaTool, sharedTool)
+	t.Cleanup(func() {
+		_ = ts.Stop(context.WithoutCancel(t.Context())) // t.Context() is already canceled by the time cleanups run
+		shutdown2()
+	})
+
+	// Call "shared" to trigger ErrSessionMissing → reconnect.
+	result, callErr := ts.callTool(t.Context(), tools.ToolCall{
+		Function: tools.FunctionCall{Name: "shared", Arguments: "{}"},
+	})
+	require.NoError(t, callErr)
+	assert.Equal(t, "ok-shared", result.Output)
+
+	// Wait for the toolsChangedHandler to be called (signals reconnect + refresh).
+	select {
+	case <-toolsChangedCh:
+		// Good — the handler was called.
+	case <-time.After(30 * time.Second):
+		t.Fatal("timed out waiting for toolsChangedHandler after reconnect")
+	}
+
+	// Verify the toolset now reports the new server's tools.
+	toolList, err = ts.Tools(t.Context())
+	require.NoError(t, err)
+	require.Len(t, toolList, 2, "expected exactly two tools from the new server")
+	toolNames = []string{toolList[0].Name, toolList[1].Name}
+	assert.Contains(t, toolNames, "ns_beta", "expected the new server's tool, got stale tool")
+	assert.Contains(t, toolNames, "ns_shared")
+	assert.NotContains(t, toolNames, "ns_alpha", "stale tool from old server should not be present")
+}