Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions pkg/tools/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ type Toolset struct {
// toolsChangedHandler is called after the tool cache is refreshed
// following a ToolListChanged notification from the server.
toolsChangedHandler func()

// restarted is closed and replaced whenever the connection is
// successfully restarted by watchConnection, allowing callers
// waiting on a reconnect to be unblocked.
restarted chan struct{}
}

// invalidateCache clears the cached tools and prompts and bumps the
Expand All @@ -68,6 +73,10 @@ func (ts *Toolset) invalidateCache() {
ts.cacheGen++
}

// sessionMissingRetryTimeout is the maximum time to wait for watchConnection
// to restart the MCP server after an ErrSessionMissing error.
//
// It is deliberately longer than tryRestart's cumulative exponential backoff
// (reportedly 1+2+4+8+16 = 31 seconds across 5 attempts — confirm against
// tryRestart's retry schedule), so that a caller blocked in
// forceReconnectAndWait does not time out while a restart is still
// legitimately in progress.
const sessionMissingRetryTimeout = 35 * time.Second

var (
_ tools.ToolSet = (*Toolset)(nil)
_ tools.Describer = (*Toolset)(nil)
Expand Down Expand Up @@ -145,6 +154,8 @@ func (ts *Toolset) Start(ctx context.Context) error {
return nil
}

ts.restarted = make(chan struct{})

if err := ts.doStart(ctx); err != nil {
if errors.Is(err, errServerUnavailable) {
// The server is unreachable but the error is non-fatal.
Expand Down Expand Up @@ -307,6 +318,9 @@ func (ts *Toolset) tryRestart(ctx context.Context) bool {
}

ts.started = true
// Signal anyone waiting for a reconnect.
close(ts.restarted)
ts.restarted = make(chan struct{})
ts.mu.Unlock()

slog.Info("MCP server restarted successfully", "server", ts.logID)
Expand Down Expand Up @@ -438,6 +452,16 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
request.Arguments = args

resp, err := ts.mcpClient.CallTool(ctx, request)

// If the server lost our session (e.g. it restarted), force a
// reconnection and retry the call once.
if errors.Is(err, mcp.ErrSessionMissing) {
slog.Warn("MCP session missing, forcing reconnect and retrying", "tool", toolCall.Function.Name, "server", ts.logID)
if waitErr := ts.forceReconnectAndWait(ctx); waitErr != nil {
return nil, fmt.Errorf("failed to reconnect after session loss: %w", waitErr)
}
resp, err = ts.mcpClient.CallTool(ctx, request)
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
slog.Debug("CallTool canceled by context", "tool", toolCall.Function.Name)
Expand All @@ -453,6 +477,33 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
return result, nil
}

// forceReconnectAndWait closes the current session to trigger watchConnection's
// restart logic, then waits for the reconnection to complete.
func (ts *Toolset) forceReconnectAndWait(ctx context.Context) error {
ts.mu.Lock()
restartCh := ts.restarted
alreadyRestarting := !ts.started
ts.mu.Unlock()

if !alreadyRestarting {
// Force-close the session so that Wait() returns and watchConnection
// kicks in with its restart loop. Skip this if watchConnection has
// already detected the disconnect (started==false) to avoid killing
// a connection that tryRestart may be establishing concurrently.
_ = ts.mcpClient.Close(context.WithoutCancel(ctx))
}

// Wait for watchConnection to complete a successful restart.
select {
case <-restartCh:
return nil
case <-ctx.Done():
return ctx.Err()
case <-time.After(sessionMissingRetryTimeout):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ MEDIUM: Timeout not coordinated with retry backoff

The sessionMissingRetryTimeout is set to 30 seconds, but tryRestart() uses exponential backoff that can take up to 31 seconds total (1+2+4+8+16 seconds across 5 retry attempts).

This creates a race condition where:

  1. A tool call encounters ErrSessionMissing and calls forceReconnectAndWait()
  2. watchConnection is in the middle of a backoff sleep (e.g., the 16-second sleep on the 5th retry)
  3. The tool call times out at 30 seconds and returns an error to the user
  4. A second later, the reconnect completes successfully
  5. The next identical tool call succeeds

Impact: Non-deterministic failures where tool calls fail with "timed out waiting for MCP server reconnection" even though the server successfully reconnects moments later.

Recommendation: Either increase sessionMissingRetryTimeout to 35-40 seconds to account for the maximum backoff duration, or coordinate the timeout with the actual retry logic (e.g., calculate remaining backoff time).

return errors.New("timed out waiting for MCP server reconnection")
}
}

func (ts *Toolset) Stop(ctx context.Context) error {
slog.Debug("Stopping MCP toolset", "server", ts.logID)

Expand Down
87 changes: 87 additions & 0 deletions pkg/tools/mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package mcp

import (
"context"
"fmt"
"iter"
"sync"
"sync/atomic"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"
Expand Down Expand Up @@ -51,6 +54,47 @@ func (m *mockMCPClient) Wait() error { return nil }

// Close is a no-op on the basic mock client.
func (m *mockMCPClient) Close(context.Context) error {
	return nil
}

// reconnectableMockClient extends mockMCPClient with reconnect simulation:
// Wait blocks until Close is called, and Initialize installs a fresh wait
// channel so each simulated "session" can be torn down independently.
type reconnectableMockClient struct {
	mockMCPClient // embedded basic mock; its other methods remain available

	mu     sync.Mutex    // guards waitCh, which is swapped on Initialize
	waitCh chan struct{} // closed when Close is called, unblocking Wait
}

// newReconnectableMock builds a reconnectableMockClient whose Wait call
// blocks until the first Close.
func newReconnectableMock() *reconnectableMockClient {
	m := new(reconnectableMockClient)
	m.waitCh = make(chan struct{})
	return m
}

// Initialize simulates a successful handshake and installs a fresh wait
// channel, starting a new "session" independent of any previous one.
func (m *reconnectableMockClient) Initialize(context.Context, *mcp.InitializeRequest) (*mcp.InitializeResult, error) {
	next := make(chan struct{})

	m.mu.Lock()
	m.waitCh = next
	m.mu.Unlock()

	return &mcp.InitializeResult{}, nil
}

// Wait blocks until the current session's channel is closed by Close,
// mimicking a transport that stays up until the connection drops.
func (m *reconnectableMockClient) Wait() error {
	m.mu.Lock()
	session := m.waitCh
	m.mu.Unlock()

	// Block outside the lock so Close and Initialize can still proceed.
	<-session
	return nil
}

// Close ends the current session by closing waitCh, which unblocks any
// pending Wait. Closing an already-closed session is a safe no-op.
func (m *reconnectableMockClient) Close(context.Context) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	select {
	case <-m.waitCh:
		// Already closed; nothing to do.
	default:
		close(m.waitCh)
	}
	return nil
}

func TestCallToolStripsNullArguments(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -251,3 +295,46 @@ func TestProcessMCPContent(t *testing.T) {
// callToolResult is a test helper that wraps the given content items in a
// CallToolResult.
func callToolResult(content ...mcp.Content) *mcp.CallToolResult {
	result := new(mcp.CallToolResult)
	result.Content = content
	return result
}

// TestCallToolRecoversFromErrSessionMissing verifies that callTool retries a
// tool call exactly once after the client reports mcp.ErrSessionMissing: the
// first CallTool fails, forceReconnectAndWait triggers watchConnection's
// restart via the reconnectable mock, and the retried call succeeds.
func TestCallToolRecoversFromErrSessionMissing(t *testing.T) {
	t.Parallel()

	// Counts CallTool invocations across the failed attempt and the retry.
	var callCount atomic.Int32

	mock := newReconnectableMock()
	mock.callToolFn = func(_ context.Context, _ *mcp.CallToolParams) (*mcp.CallToolResult, error) {
		n := callCount.Add(1)
		if n == 1 {
			// First call: simulate server restart by returning ErrSessionMissing,
			// wrapped so that callTool's errors.Is check must unwrap it.
			return nil, fmt.Errorf("tools/call: %w", mcp.ErrSessionMissing)
		}
		// Second call (after reconnect): succeed.
		return &mcp.CallToolResult{
			Content: []mcp.Content{&mcp.TextContent{Text: "recovered"}},
		}, nil
	}

	// Hand-assembled Toolset in the "connected" state, as Start() would
	// leave it: started=true and the restarted channel allocated.
	ts := &Toolset{
		started:   true,
		mcpClient: mock,
		logID:     "test-server",
		restarted: make(chan struct{}),
	}

	// Start the watchConnection goroutine as Start() would.
	go ts.watchConnection(t.Context())

	result, err := ts.callTool(t.Context(), tools.ToolCall{
		Function: tools.FunctionCall{
			Name:      "test_tool",
			Arguments: `{"key": "value"}`,
		},
	})

	require.NoError(t, err)
	assert.Equal(t, "recovered", result.Output)
	assert.Equal(t, int32(2), callCount.Load(), "expected exactly 2 CallTool invocations (1 failed + 1 retry)")

	// Clean up: stop the watcher.
	_ = ts.Stop(t.Context())
}
125 changes: 125 additions & 0 deletions pkg/tools/mcp/reconnect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package mcp

import (
"context"
"fmt"
"net"
"net/http"
"sync/atomic"
"testing"
"time"

"github.com/google/jsonschema-go/jsonschema"
gomcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/tools"
)

// TestRemoteReconnectAfterServerRestart verifies that a Toolset backed by a
// real remote (streamable-HTTP) MCP server transparently recovers when the
// server is restarted.
//
// The scenario:
//  1. Start a minimal MCP server with a "ping" tool.
//  2. Connect a Toolset, call "ping" — succeeds.
//  3. Shut down the server (simulates crash / restart).
//  4. Start a **new** server on the same address.
//  5. Call "ping" again — this must succeed after automatic reconnection.
//
// Without the ErrSessionMissing recovery logic the second call would fail
// because the new server does not know the old session ID.
func TestRemoteReconnectAfterServerRestart(t *testing.T) {
	t.Parallel()

	// Use a fixed listener address so we can restart on the same port.
	// NOTE(review): releasing a port obtained via :0 and rebinding it later is
	// inherently racy — another process could grab the port in between. The
	// Eventually retry in startServer mitigates but cannot fully eliminate this.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := ln.Addr().String()
	ln.Close() // We only needed the address; close so startServer can bind it.

	// Shared across server incarnations, so responses number pong-1, pong-2, ...
	var callCount atomic.Int32

	// startServer creates a minimal MCP server on addr with a "ping" tool
	// and returns a function to shut it down.
	startServer := func(t *testing.T) (shutdown func()) {
		t.Helper()

		s := gomcp.NewServer(&gomcp.Implementation{Name: "test-server", Version: "1.0.0"}, nil)
		s.AddTool(&gomcp.Tool{
			Name:        "ping",
			InputSchema: &jsonschema.Schema{Type: "object"},
		}, func(_ context.Context, _ *gomcp.CallToolRequest) (*gomcp.CallToolResult, error) {
			n := callCount.Add(1)
			return &gomcp.CallToolResult{
				Content: []gomcp.Content{&gomcp.TextContent{Text: fmt.Sprintf("pong-%d", n)}},
			}, nil
		})

		// Retry Listen until the port is available (e.g. after a server shutdown).
		var srvLn net.Listener
		require.Eventually(t, func() bool {
			var listenErr error
			srvLn, listenErr = net.Listen("tcp", addr)
			return listenErr == nil
		}, 2*time.Second, 50*time.Millisecond, "port %s not available in time", addr)

		srv := &http.Server{
			Handler: gomcp.NewStreamableHTTPHandler(func(*http.Request) *gomcp.Server { return s }, nil),
		}
		go func() { _ = srv.Serve(srvLn) }()

		return func() { _ = srv.Close() }
	}

	// callPing invokes the "ping" tool through the toolset and returns its
	// text output, failing the test on any call error.
	callPing := func(t *testing.T, ts *Toolset) string {
		t.Helper()
		result, callErr := ts.callTool(t.Context(), tools.ToolCall{
			Function: tools.FunctionCall{Name: "ping", Arguments: "{}"},
		})
		require.NoError(t, callErr)
		return result.Output
	}

	// --- Steps 1–2: start the first server and connect the toolset ---
	shutdown1 := startServer(t)

	ts := NewRemoteToolset("test", fmt.Sprintf("http://%s/mcp", addr), "streamable-http", nil)
	require.NoError(t, ts.Start(t.Context()))

	toolList, err := ts.Tools(t.Context())
	require.NoError(t, err)
	require.Len(t, toolList, 1)
	assert.Equal(t, "test_ping", toolList[0].Name)

	// First call succeeds against the original server (completing step 2).
	assert.Equal(t, "pong-1", callPing(t, ts))

	// --- Step 3: shut down the server (simulated crash/restart) ---
	shutdown1()

	// Capture the current restarted channel before the reconnect so we can
	// later verify that watchConnection actually signaled a restart.
	ts.mu.Lock()
	restartedCh := ts.restarted
	ts.mu.Unlock()

	// --- Steps 4–5: start a fresh server on the same address, call again ---
	shutdown2 := startServer(t)
	t.Cleanup(func() {
		_ = ts.Stop(t.Context())
		shutdown2()
	})

	// This call triggers ErrSessionMissing recovery and must succeed transparently.
	assert.Equal(t, "pong-2", callPing(t, ts))

	// Verify that watchConnection actually restarted the connection by checking
	// that the restarted channel was closed (signaling reconnect completion).
	select {
	case <-restartedCh:
		// Success: the channel was closed, meaning reconnect happened.
	case <-time.After(100 * time.Millisecond):
		t.Fatal("reconnect did not complete: restarted channel was not closed")
	}
}
Loading