diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go
index 48cf1f956..46ea4b77e 100644
--- a/pkg/tools/mcp/mcp.go
+++ b/pkg/tools/mcp/mcp.go
@@ -292,6 +292,13 @@ func (ts *Toolset) watchConnection(ctx context.Context) {
 		if !ts.tryRestart(ctx) {
 			return
 		}
+
+		// After a successful restart, eagerly refresh the tool and prompt
+		// caches and notify the runtime so it picks up the new server's
+		// state. The new server may expose a different set of tools/prompts,
+		// and without this the runtime would keep using its stale copy.
+		ts.refreshToolCache(ctx)
+		ts.refreshPromptCache(ctx)
 	}
 }
 
diff --git a/pkg/tools/mcp/reconnect_test.go b/pkg/tools/mcp/reconnect_test.go
index a95ff4a9c..eeacc52a6 100644
--- a/pkg/tools/mcp/reconnect_test.go
+++ b/pkg/tools/mcp/reconnect_test.go
@@ -17,6 +17,46 @@ import (
 	"github.com/docker/docker-agent/pkg/tools"
 )
 
+// startMCPServer creates a minimal MCP server on addr with the given tools
+// and returns a function to shut it down.
+func startMCPServer(t *testing.T, addr string, mcpTools ...*gomcp.Tool) (shutdown func()) {
+	t.Helper()
+
+	s := gomcp.NewServer(&gomcp.Implementation{Name: "test-server", Version: "1.0.0"}, nil)
+	for _, tool := range mcpTools {
+		s.AddTool(tool, func(_ context.Context, _ *gomcp.CallToolRequest) (*gomcp.CallToolResult, error) {
+			return &gomcp.CallToolResult{
+				Content: []gomcp.Content{&gomcp.TextContent{Text: "ok-" + tool.Name}},
+			}, nil
+		})
+	}
+
+	// Retry Listen until the port is available (e.g. after a server shutdown).
+	var srvLn net.Listener
+	require.Eventually(t, func() bool {
+		var listenErr error
+		srvLn, listenErr = net.Listen("tcp", addr)
+		return listenErr == nil
+	}, 2*time.Second, 50*time.Millisecond, "port %s not available in time", addr)
+
+	srv := &http.Server{
+		Handler: gomcp.NewStreamableHTTPHandler(func(*http.Request) *gomcp.Server { return s }, nil),
+	}
+	go func() { _ = srv.Serve(srvLn) }()
+
+	return func() { _ = srv.Close() }
+}
+
+// allocateAddr returns a free TCP address on localhost.
+func allocateAddr(t *testing.T) string {
+	t.Helper()
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	addr := ln.Addr().String()
+	ln.Close()
+	return addr
+}
+
 // TestRemoteReconnectAfterServerRestart verifies that a Toolset backed by a
 // real remote (streamable-HTTP) MCP server transparently recovers when the
 // server is restarted.
@@ -33,11 +73,7 @@ import (
 func TestRemoteReconnectAfterServerRestart(t *testing.T) {
 	t.Parallel()
 
-	// Use a fixed listener address so we can restart on the same port.
-	ln, err := net.Listen("tcp", "127.0.0.1:0")
-	require.NoError(t, err)
-	addr := ln.Addr().String()
-	ln.Close() // We only needed the address; close so startServer can bind it.
+	addr := allocateAddr(t)
 
 	var callCount atomic.Int32
 
@@ -123,3 +159,84 @@ func TestRemoteReconnectAfterServerRestart(t *testing.T) {
 		t.Fatal("reconnect did not complete: restarted channel was not closed")
 	}
 }
+
+// TestRemoteReconnectRefreshesTools verifies that after a remote MCP server
+// restarts with a different set of tools, the Toolset picks up the new tools
+// and notifies the runtime via the toolsChangedHandler.
+//
+// This is the scenario from https://github.com/docker/docker-agent/issues/2244:
+//   - Server v1 exposes tools [alpha, shared].
+//   - Client connects and caches [alpha, shared].
+//   - Server v1 shuts down; server v2 starts with tools [beta, shared].
+//   - A tool call to "shared" triggers reconnection.
+//   - After reconnection, Tools() must return [beta, shared], not the stale [alpha, shared].
+//   - The toolsChangedHandler must be called so the runtime refreshes its own state.
+func TestRemoteReconnectRefreshesTools(t *testing.T) {
+	t.Parallel()
+
+	addr := allocateAddr(t)
+
+	// "shared" exists on both servers so we can call it to trigger reconnect.
+	sharedTool := &gomcp.Tool{Name: "shared", InputSchema: &jsonschema.Schema{Type: "object"}}
+	alphaTool := &gomcp.Tool{Name: "alpha", InputSchema: &jsonschema.Schema{Type: "object"}}
+	betaTool := &gomcp.Tool{Name: "beta", InputSchema: &jsonschema.Schema{Type: "object"}}
+
+	// --- Start server v1 with tools "alpha" + "shared" ---
+	shutdown1 := startMCPServer(t, addr, alphaTool, sharedTool)
+
+	ts := NewRemoteToolset("ns", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil)
+
+	// Track toolsChangedHandler invocations.
+	toolsChangedCh := make(chan struct{}, 1)
+	ts.SetToolsChangedHandler(func() {
+		select {
+		case toolsChangedCh <- struct{}{}:
+		default:
+		}
+	})
+
+	require.NoError(t, ts.Start(t.Context()))
+
+	// Verify initial tools.
+	toolList, err := ts.Tools(t.Context())
+	require.NoError(t, err)
+	require.Len(t, toolList, 2)
+	toolNames := []string{toolList[0].Name, toolList[1].Name}
+	assert.Contains(t, toolNames, "ns_alpha")
+	assert.Contains(t, toolNames, "ns_shared")
+
+	// --- Shut down server v1, start server v2 with tools "beta" + "shared" ---
+	shutdown1()
+
+	shutdown2 := startMCPServer(t, addr, betaTool, sharedTool)
+	t.Cleanup(func() {
+		// t.Context() is canceled just before cleanups run, so detach from its
+		// cancellation to hand Stop a live context for the shutdown.
+		_ = ts.Stop(context.WithoutCancel(t.Context()))
+		shutdown2()
+	})
+
+	// Call "shared" to trigger ErrSessionMissing → reconnect.
+	result, callErr := ts.callTool(t.Context(), tools.ToolCall{
+		Function: tools.FunctionCall{Name: "shared", Arguments: "{}"},
+	})
+	require.NoError(t, callErr)
+	assert.Equal(t, "ok-shared", result.Output)
+
+	// Wait for the toolsChangedHandler to be called (signals reconnect + refresh).
+	select {
+	case <-toolsChangedCh:
+		// Good — the handler was called.
+	case <-time.After(30 * time.Second):
+		t.Fatal("timed out waiting for toolsChangedHandler after reconnect")
+	}
+
+	// Verify the toolset now reports the new server's tools.
+	toolList, err = ts.Tools(t.Context())
+	require.NoError(t, err)
+	require.Len(t, toolList, 2, "expected exactly two tools from the new server")
+	toolNames = []string{toolList[0].Name, toolList[1].Name}
+	assert.Contains(t, toolNames, "ns_beta", "expected the new server's tool, got stale tool")
+	assert.Contains(t, toolNames, "ns_shared")
+	assert.NotContains(t, toolNames, "ns_alpha", "stale tool from old server should not be present")
+}