Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions internal/mcp/connect_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,37 @@ func TestDefaultConnectTimeout_Value(t *testing.T) {
assert.Equal(t, 30*time.Second, defaultConnectTimeout,
"defaultConnectTimeout must remain 30 s to stay in sync with config.DefaultConnectTimeout")
}

func TestNormalizeConnectTimeout(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
input time.Duration
expected time.Duration
}{
{
name: "zero input uses default",
input: 0,
expected: defaultConnectTimeout,
},
{
name: "negative input uses default",
input: -1 * time.Second,
expected: defaultConnectTimeout,
},
{
name: "positive input is unchanged",
input: 15 * time.Second,
expected: 15 * time.Second,
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expected, normalizeConnectTimeout(tc.input))
})
}
}
44 changes: 17 additions & 27 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ var logConn = logger.New("mcp:connection")
// Kept in sync with config.DefaultConnectTimeout (30 s) to avoid importing the config package.
const defaultConnectTimeout = 30 * time.Second

// normalizeConnectTimeout returns defaultConnectTimeout when the input timeout is
// non-positive, otherwise it returns the input timeout unchanged.
func normalizeConnectTimeout(timeout time.Duration) time.Duration {
if timeout <= 0 {
return defaultConnectTimeout
}
return timeout
}

// ContextKey for session ID
type ContextKey string

Expand Down Expand Up @@ -203,9 +212,7 @@ func NewConnection(ctx context.Context, serverID, command string, args []string,
// This ensures compatibility with all types of HTTP MCP servers.
func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[string]string, oidcProvider *oidc.Provider, oidcAudience string, keepAlive time.Duration, connectTimeout time.Duration) (*Connection, error) {
// Apply default connect timeout when not specified
if connectTimeout <= 0 {
connectTimeout = defaultConnectTimeout
}
connectTimeout = normalizeConnectTimeout(connectTimeout)
logger.LogInfo("backend", "Creating HTTP MCP connection with transport fallback, url=%s, connectTimeout=%v", url, connectTimeout)
ctx, cancel := context.WithCancel(ctx)

Expand Down Expand Up @@ -383,10 +390,7 @@ func (c *Connection) reconnectSDKTransport() error {
return fmt.Errorf("cannot reconnect: unsupported transport type %s", c.httpTransportType)
}

timeout := c.connectTimeout
if timeout <= 0 {
timeout = defaultConnectTimeout
}
timeout := normalizeConnectTimeout(c.connectTimeout)
connectCtx, cancel := context.WithTimeout(c.ctx, timeout)
defer cancel()

Expand Down Expand Up @@ -468,29 +472,15 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
// For plain JSON-RPC transport, use manual HTTP requests
if c.httpTransportType == HTTPTransportPlainJSON {
result, err = c.sendHTTPRequest(ctx, method, params)
// Log the response from backend server
var responsePayload []byte
if result != nil {
responsePayload, _ = json.Marshal(result)
}
logInboundRPCResponse(serverID, responsePayload, err, shouldAttachAgentTags, snapshot)
return result, err
}

// For streamable and SSE transports, use SDK session methods
result, err = c.callSDKMethodWithReconnect(method, params)
// Log the response from backend server
var responsePayload []byte
if result != nil {
responsePayload, _ = json.Marshal(result)
} else {
// For streamable and SSE transports, use SDK session methods
result, err = c.callSDKMethodWithReconnect(method, params)
}
logInboundRPCResponse(serverID, responsePayload, err, shouldAttachAgentTags, snapshot)
return result, err
} else {
// Handle stdio connections using SDK client
result, err = c.callSDKMethod(method, params)
}

// Handle stdio connections using SDK client
result, err = c.callSDKMethod(method, params)

// Log the response from backend server
var responsePayload []byte
if result != nil {
Expand Down
Loading