diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go
index f75a71571..273a71bd7 100644
--- a/pkg/tools/mcp/mcp.go
+++ b/pkg/tools/mcp/mcp.go
@@ -58,6 +58,11 @@ type Toolset struct {
 	// toolsChangedHandler is called after the tool cache is refreshed
 	// following a ToolListChanged notification from the server.
 	toolsChangedHandler func()
+
+	// restarted is closed and replaced whenever the connection is
+	// successfully restarted by watchConnection, allowing callers
+	// waiting on a reconnect to be unblocked.
+	restarted chan struct{}
 }
 
 // invalidateCache clears the cached tools and prompts and bumps the
@@ -68,6 +73,10 @@ func (ts *Toolset) invalidateCache() {
 	ts.cacheGen++
 }
 
+// sessionMissingRetryTimeout is the maximum time to wait for watchConnection
+// to restart the MCP server after an ErrSessionMissing error.
+const sessionMissingRetryTimeout = 35 * time.Second
+
 var (
 	_ tools.ToolSet   = (*Toolset)(nil)
 	_ tools.Describer = (*Toolset)(nil)
@@ -145,6 +154,8 @@ func (ts *Toolset) Start(ctx context.Context) error {
 		return nil
 	}
 
+	ts.restarted = make(chan struct{})
+
 	if err := ts.doStart(ctx); err != nil {
 		if errors.Is(err, errServerUnavailable) {
 			// The server is unreachable but the error is non-fatal.
@@ -307,6 +318,9 @@ func (ts *Toolset) tryRestart(ctx context.Context) bool {
 	}
 
 	ts.started = true
+	// Signal anyone waiting for a reconnect.
+	close(ts.restarted)
+	ts.restarted = make(chan struct{})
 	ts.mu.Unlock()
 
 	slog.Info("MCP server restarted successfully", "server", ts.logID)
@@ -438,6 +452,16 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
 	request.Arguments = args
 
 	resp, err := ts.mcpClient.CallTool(ctx, request)
+
+	// If the server lost our session (e.g. it restarted), force a
+	// reconnection and retry the call once.
+	if errors.Is(err, mcp.ErrSessionMissing) {
+		slog.Warn("MCP session missing, forcing reconnect and retrying", "tool", toolCall.Function.Name, "server", ts.logID)
+		if waitErr := ts.forceReconnectAndWait(ctx); waitErr != nil {
+			return nil, fmt.Errorf("failed to reconnect after session loss: %w", waitErr)
+		}
+		resp, err = ts.mcpClient.CallTool(ctx, request)
+	}
 	if err != nil {
 		if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
 			slog.Debug("CallTool canceled by context", "tool", toolCall.Function.Name)
@@ -453,6 +477,36 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
 	return result, nil
 }
 
+// forceReconnectAndWait closes the current session to trigger watchConnection's
+// restart logic, then waits for the reconnection to complete.
+//
+// It returns nil once the restarted channel is closed, ctx.Err() if ctx is done
+// first, or an error once sessionMissingRetryTimeout has elapsed.
+func (ts *Toolset) forceReconnectAndWait(ctx context.Context) error {
+	ts.mu.Lock()
+	restartCh := ts.restarted
+	alreadyRestarting := !ts.started
+	ts.mu.Unlock()
+
+	if !alreadyRestarting {
+		// Force-close the session so that Wait() returns and watchConnection
+		// kicks in with its restart loop. Skip this if watchConnection has
+		// already detected the disconnect (started==false) to avoid killing
+		// a connection that tryRestart may be establishing concurrently.
+		_ = ts.mcpClient.Close(context.WithoutCancel(ctx))
+	}
+
+	// Wait for watchConnection to complete a successful restart.
+	select {
+	case <-restartCh:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-time.After(sessionMissingRetryTimeout):
+		return errors.New("timed out waiting for MCP server reconnection")
+	}
+}
+
 func (ts *Toolset) Stop(ctx context.Context) error {
 	slog.Debug("Stopping MCP toolset", "server", ts.logID)
 
diff --git a/pkg/tools/mcp/mcp_test.go b/pkg/tools/mcp/mcp_test.go
index d75164950..f9fb1ff7f 100644
--- a/pkg/tools/mcp/mcp_test.go
+++ b/pkg/tools/mcp/mcp_test.go
@@ -2,7 +2,10 @@ package mcp
 
 import (
 	"context"
+	"fmt"
 	"iter"
+	"sync"
+	"sync/atomic"
 	"testing"
 
 	"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -51,6 +54,47 @@ func (m *mockMCPClient) Wait() error { return nil }
 
 func (m *mockMCPClient) Close(context.Context) error { return nil }
 
+// reconnectableMockClient extends mockMCPClient with reconnect simulation.
+type reconnectableMockClient struct {
+	mockMCPClient
+
+	mu     sync.Mutex
+	waitCh chan struct{} // closed when Close is called, unblocking Wait
+}
+
+func newReconnectableMock() *reconnectableMockClient {
+	return &reconnectableMockClient{
+		waitCh: make(chan struct{}),
+	}
+}
+
+func (m *reconnectableMockClient) Initialize(context.Context, *mcp.InitializeRequest) (*mcp.InitializeResult, error) {
+	m.mu.Lock()
+	m.waitCh = make(chan struct{}) // fresh channel for each session
+	m.mu.Unlock()
+	return &mcp.InitializeResult{}, nil
+}
+
+func (m *reconnectableMockClient) Wait() error {
+	m.mu.Lock()
+	ch := m.waitCh
+	m.mu.Unlock()
+	<-ch
+	return nil
+}
+
+func (m *reconnectableMockClient) Close(context.Context) error {
+	m.mu.Lock()
+	// Close the wait channel to unblock Wait().
+	select {
+	case <-m.waitCh:
+	default:
+		close(m.waitCh)
+	}
+	m.mu.Unlock()
+	return nil
+}
+
 func TestCallToolStripsNullArguments(t *testing.T) {
 	t.Parallel()
 
@@ -251,3 +295,46 @@ func TestProcessMCPContent(t *testing.T) {
 func callToolResult(content ...mcp.Content) *mcp.CallToolResult {
 	return &mcp.CallToolResult{Content: content}
 }
+
+func TestCallToolRecoversFromErrSessionMissing(t *testing.T) {
+	t.Parallel()
+
+	var callCount atomic.Int32
+
+	mock := newReconnectableMock()
+	mock.callToolFn = func(_ context.Context, _ *mcp.CallToolParams) (*mcp.CallToolResult, error) {
+		n := callCount.Add(1)
+		if n == 1 {
+			// First call: simulate server restart by returning ErrSessionMissing.
+			return nil, fmt.Errorf("tools/call: %w", mcp.ErrSessionMissing)
+		}
+		// Second call (after reconnect): succeed.
+		return &mcp.CallToolResult{
+			Content: []mcp.Content{&mcp.TextContent{Text: "recovered"}},
+		}, nil
+	}
+
+	ts := &Toolset{
+		started:   true,
+		mcpClient: mock,
+		logID:     "test-server",
+		restarted: make(chan struct{}),
+	}
+
+	// Start the watchConnection goroutine as Start() would.
+	go ts.watchConnection(t.Context())
+
+	result, err := ts.callTool(t.Context(), tools.ToolCall{
+		Function: tools.FunctionCall{
+			Name:      "test_tool",
+			Arguments: `{"key": "value"}`,
+		},
+	})
+
+	require.NoError(t, err)
+	assert.Equal(t, "recovered", result.Output)
+	assert.Equal(t, int32(2), callCount.Load(), "expected exactly 2 CallTool invocations (1 failed + 1 retry)")
+
+	// Clean up: stop the watcher.
+	_ = ts.Stop(t.Context())
+}
diff --git a/pkg/tools/mcp/reconnect_test.go b/pkg/tools/mcp/reconnect_test.go
new file mode 100644
index 000000000..a95ff4a9c
--- /dev/null
+++ b/pkg/tools/mcp/reconnect_test.go
@@ -0,0 +1,127 @@
+package mcp
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"net/http"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/google/jsonschema-go/jsonschema"
+	gomcp "github.com/modelcontextprotocol/go-sdk/mcp"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/docker/docker-agent/pkg/tools"
+)
+
+// TestRemoteReconnectAfterServerRestart verifies that a Toolset backed by a
+// real remote (streamable-HTTP) MCP server transparently recovers when the
+// server is restarted.
+//
+// The scenario:
+//  1. Start a minimal MCP server with a "ping" tool.
+//  2. Connect a Toolset, call "ping" — succeeds.
+//  3. Shut down the server (simulates crash / restart).
+//  4. Start a **new** server on the same address.
+//  5. Call "ping" again — this must succeed after automatic reconnection.
+//
+// Without the ErrSessionMissing recovery logic the second call would fail
+// because the new server does not know the old session ID.
+func TestRemoteReconnectAfterServerRestart(t *testing.T) {
+	t.Parallel()
+
+	// Reserve an OS-assigned loopback port so both servers can bind the same address.
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	addr := ln.Addr().String()
+	require.NoError(t, ln.Close()) // We only needed the address; close so startServer can bind it.
+
+	var callCount atomic.Int32
+
+	// startServer creates a minimal MCP server on addr with a "ping" tool
+	// and returns a function to shut it down.
+	startServer := func(t *testing.T) (shutdown func()) {
+		t.Helper()
+
+		s := gomcp.NewServer(&gomcp.Implementation{Name: "test-server", Version: "1.0.0"}, nil)
+		s.AddTool(&gomcp.Tool{
+			Name:        "ping",
+			InputSchema: &jsonschema.Schema{Type: "object"},
+		}, func(_ context.Context, _ *gomcp.CallToolRequest) (*gomcp.CallToolResult, error) {
+			n := callCount.Add(1)
+			return &gomcp.CallToolResult{
+				Content: []gomcp.Content{&gomcp.TextContent{Text: fmt.Sprintf("pong-%d", n)}},
+			}, nil
+		})
+
+		// Retry Listen until the port is available (e.g. after a server shutdown).
+		var srvLn net.Listener
+		require.Eventually(t, func() bool {
+			var listenErr error
+			srvLn, listenErr = net.Listen("tcp", addr)
+			return listenErr == nil
+		}, 2*time.Second, 50*time.Millisecond, "port %s not available in time", addr)
+
+		srv := &http.Server{
+			Handler: gomcp.NewStreamableHTTPHandler(func(*http.Request) *gomcp.Server { return s }, nil),
+		}
+		go func() { _ = srv.Serve(srvLn) }()
+
+		return func() { _ = srv.Close() }
+	}
+
+	callPing := func(t *testing.T, ts *Toolset) string {
+		t.Helper()
+		result, callErr := ts.callTool(t.Context(), tools.ToolCall{
+			Function: tools.FunctionCall{Name: "ping", Arguments: "{}"},
+		})
+		require.NoError(t, callErr)
+		return result.Output
+	}
+
+	// --- Steps 1–2: Start first server, connect toolset ---
+	shutdown1 := startServer(t)
+
+	ts := NewRemoteToolset("test", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil)
+	require.NoError(t, ts.Start(t.Context()))
+
+	toolList, err := ts.Tools(t.Context())
+	require.NoError(t, err)
+	require.Len(t, toolList, 1)
+	assert.Equal(t, "test_ping", toolList[0].Name)
+
+	// --- Step 2 (cont.): Call succeeds on original server ---
+	assert.Equal(t, "pong-1", callPing(t, ts))
+
+	// --- Step 3: Shut down the server ---
+	shutdown1()
+
+	// Capture the current restarted channel before the reconnect
+	ts.mu.Lock()
+	restartedCh := ts.restarted
+	ts.mu.Unlock()
+
+	// --- Steps 4–5: Start a fresh server, call again ---
+	shutdown2 := startServer(t)
+	t.Cleanup(func() {
+		// t.Context() is canceled just before cleanup functions run, so
+		// detach from it to hand Stop a context that is still live.
+		_ = ts.Stop(context.WithoutCancel(t.Context()))
+		shutdown2()
+	})
+
+	// This call triggers ErrSessionMissing recovery and must succeed transparently.
+	assert.Equal(t, "pong-2", callPing(t, ts))
+
+	// Verify that watchConnection actually restarted the connection by checking
+	// that the restarted channel was closed (signaling reconnect completion).
+	select {
+	case <-restartedCh:
+		// Success: the channel was closed, meaning reconnect happened
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("reconnect did not complete: restarted channel was not closed")
+	}
+}