Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/tools/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,13 @@ func (ts *Toolset) watchConnection(ctx context.Context) {
if !ts.tryRestart(ctx) {
return
}

// After a successful restart, eagerly refresh the tool and prompt
// caches and notify the runtime so it picks up the new server's
// state. The new server may expose a different set of tools/prompts,
// and without this the runtime would keep using its stale copy.
ts.refreshToolCache(ctx)
ts.refreshPromptCache(ctx)
}
}

Expand Down
125 changes: 120 additions & 5 deletions pkg/tools/mcp/reconnect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,46 @@ import (
"github.com/docker/docker-agent/pkg/tools"
)

// startMCPServer serves a minimal streamable-HTTP MCP server on addr that
// exposes the given tools, and returns a function that closes the server.
//
// Every registered tool answers with the text "ok-<tool name>". Binding is
// retried for up to two seconds so the helper can be called right after a
// previous server on the same address has been shut down; the test fails if
// the address never becomes available.
func startMCPServer(t *testing.T, addr string, mcpTools ...*gomcp.Tool) (shutdown func()) {
	t.Helper()

	server := gomcp.NewServer(&gomcp.Implementation{Name: "test-server", Version: "1.0.0"}, nil)
	for _, mcpTool := range mcpTools {
		// The handler closes over the loop variable; this relies on the
		// per-iteration loop variable semantics of Go 1.22+.
		handler := func(_ context.Context, _ *gomcp.CallToolRequest) (*gomcp.CallToolResult, error) {
			reply := &gomcp.TextContent{Text: "ok-" + mcpTool.Name}
			return &gomcp.CallToolResult{Content: []gomcp.Content{reply}}, nil
		}
		server.AddTool(mcpTool, handler)
	}

	// Keep trying to bind until the port is free (e.g. after a server shutdown).
	var listener net.Listener
	tryListen := func() bool {
		l, err := net.Listen("tcp", addr)
		if err != nil {
			return false
		}
		listener = l
		return true
	}
	require.Eventually(t, tryListen, 2*time.Second, 50*time.Millisecond, "port %s not available in time", addr)

	getServer := func(*http.Request) *gomcp.Server { return server }
	httpSrv := &http.Server{
		Handler: gomcp.NewStreamableHTTPHandler(getServer, nil),
	}
	// Serve returns once the server is closed; that error is expected and ignored.
	go func() { _ = httpSrv.Serve(listener) }()

	shutdown = func() { _ = httpSrv.Close() }
	return shutdown
}

// allocateAddr returns a TCP address on localhost that was free when the
// function was called.
//
// It binds an ephemeral port, records the resulting address and releases the
// listener again so that a test server can bind the same address afterwards.
// The port is not reserved: something else may claim it between the release
// here and the caller's own bind, so callers should tolerate a bind failure.
func allocateAddr(t *testing.T) string {
	t.Helper()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	addr := ln.Addr().String()
	// Fail loudly if the listener cannot be released: a still-open listener
	// would make every later bind on addr fail with a misleading timeout.
	require.NoError(t, ln.Close())
	return addr
}

// TestRemoteReconnectAfterServerRestart verifies that a Toolset backed by a
// real remote (streamable-HTTP) MCP server transparently recovers when the
// server is restarted.
Expand All @@ -33,11 +73,7 @@ import (
func TestRemoteReconnectAfterServerRestart(t *testing.T) {
t.Parallel()

// Use a fixed listener address so we can restart on the same port.
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := ln.Addr().String()
ln.Close() // We only needed the address; close so startServer can bind it.
addr := allocateAddr(t)

var callCount atomic.Int32

Expand Down Expand Up @@ -123,3 +159,82 @@ func TestRemoteReconnectAfterServerRestart(t *testing.T) {
t.Fatal("reconnect did not complete: restarted channel was not closed")
}
}

// TestRemoteReconnectRefreshesTools verifies that after a remote MCP server
// restarts with a different set of tools, the Toolset picks up the new tools
// and notifies the runtime via the toolsChangedHandler.
//
// This is the scenario from https://github.com/docker/docker-agent/issues/2244:
//   - Server v1 exposes tools [alpha, shared].
//   - Client connects and caches [alpha, shared].
//   - Server v1 shuts down; server v2 starts with tools [beta, shared].
//   - A tool call to "shared" triggers reconnection.
//   - After reconnection, Tools() must return [beta, shared], not the stale [alpha, shared].
//   - The toolsChangedHandler must be called so the runtime refreshes its own state.
//
// Cleanup is registered as soon as each resource exists, so a failing
// assertion anywhere in the test does not leak a server or a started toolset.
func TestRemoteReconnectRefreshesTools(t *testing.T) {
	t.Parallel()

	addr := allocateAddr(t)

	// "shared" exists on both servers so we can call it to trigger reconnect.
	sharedTool := &gomcp.Tool{Name: "shared", InputSchema: &jsonschema.Schema{Type: "object"}}
	alphaTool := &gomcp.Tool{Name: "alpha", InputSchema: &jsonschema.Schema{Type: "object"}}
	betaTool := &gomcp.Tool{Name: "beta", InputSchema: &jsonschema.Schema{Type: "object"}}

	// --- Start server v1 with tools "alpha" + "shared" ---
	// shutdownServer always refers to the most recently started server. The
	// cleanup reads the variable when it runs, so it closes whichever server
	// is current at that point; closing an already closed server is harmless
	// because the shutdown function discards the error.
	shutdownServer := startMCPServer(t, addr, alphaTool, sharedTool)
	t.Cleanup(func() { shutdownServer() })

	ts := NewRemoteToolset("ns", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil)

	// Track toolsChangedHandler invocations. The buffer of one plus the
	// non-blocking send keeps the handler from ever blocking the toolset.
	toolsChangedCh := make(chan struct{}, 1)
	ts.SetToolsChangedHandler(func() {
		select {
		case toolsChangedCh <- struct{}{}:
		default:
		}
	})

	require.NoError(t, ts.Start(t.Context()))
	// Cleanups run last-in-first-out, so the toolset is stopped before the
	// server is closed. t.Context() is already cancelled by the time cleanup
	// functions run, so strip the cancellation to hand Stop a live context.
	t.Cleanup(func() {
		_ = ts.Stop(context.WithoutCancel(t.Context()))
	})

	// Verify initial tools.
	toolList, err := ts.Tools(t.Context())
	require.NoError(t, err)
	require.Len(t, toolList, 2)
	toolNames := []string{toolList[0].Name, toolList[1].Name}
	assert.Contains(t, toolNames, "ns_alpha")
	assert.Contains(t, toolNames, "ns_shared")

	// --- Shut down server v1, start server v2 with tools "beta" + "shared" ---
	shutdownServer()
	shutdownServer = startMCPServer(t, addr, betaTool, sharedTool)

	// Call "shared" to trigger ErrSessionMissing → reconnect.
	result, callErr := ts.callTool(t.Context(), tools.ToolCall{
		Function: tools.FunctionCall{Name: "shared", Arguments: "{}"},
	})
	require.NoError(t, callErr)
	assert.Equal(t, "ok-shared", result.Output)

	// Wait for the toolsChangedHandler to be called (signals reconnect + refresh).
	select {
	case <-toolsChangedCh:
		// Good — the handler was called.
	case <-time.After(30 * time.Second):
		t.Fatal("timed out waiting for toolsChangedHandler after reconnect")
	}

	// Verify the toolset now reports the new server's tools.
	toolList, err = ts.Tools(t.Context())
	require.NoError(t, err)
	require.Len(t, toolList, 2, "expected exactly two tools from the new server")
	toolNames = []string{toolList[0].Name, toolList[1].Name}
	assert.Contains(t, toolNames, "ns_beta", "expected the new server's tool, got stale tool")
	assert.Contains(t, toolNames, "ns_shared")
	assert.NotContains(t, toolNames, "ns_alpha", "stale tool from old server should not be present")
}
Loading