diff --git a/internal/mcp/connect_timeout_test.go b/internal/mcp/connect_timeout_test.go index c5509353..7e313ac9 100644 --- a/internal/mcp/connect_timeout_test.go +++ b/internal/mcp/connect_timeout_test.go @@ -98,3 +98,37 @@ func TestDefaultConnectTimeout_Value(t *testing.T) { assert.Equal(t, 30*time.Second, defaultConnectTimeout, "defaultConnectTimeout must remain 30 s to stay in sync with config.DefaultConnectTimeout") } + +func TestNormalizeConnectTimeout(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input time.Duration + expected time.Duration + }{ + { + name: "zero input uses default", + input: 0, + expected: defaultConnectTimeout, + }, + { + name: "negative input uses default", + input: -1 * time.Second, + expected: defaultConnectTimeout, + }, + { + name: "positive input is unchanged", + input: 15 * time.Second, + expected: 15 * time.Second, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, normalizeConnectTimeout(tc.input)) + }) + } +} diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 8ff4b35a..05953f8a 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -29,6 +29,15 @@ var logConn = logger.New("mcp:connection") // Kept in sync with config.DefaultConnectTimeout (30 s) to avoid importing the config package. const defaultConnectTimeout = 30 * time.Second +// normalizeConnectTimeout returns defaultConnectTimeout when the input timeout is +// non-positive, otherwise it returns the input timeout unchanged. +func normalizeConnectTimeout(timeout time.Duration) time.Duration { + if timeout <= 0 { + return defaultConnectTimeout + } + return timeout +} + // ContextKey for session ID type ContextKey string @@ -203,9 +212,7 @@ func NewConnection(ctx context.Context, serverID, command string, args []string, // This ensures compatibility with all types of HTTP MCP servers. func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[string]string, oidcProvider *oidc.Provider, oidcAudience string, keepAlive time.Duration, connectTimeout time.Duration) (*Connection, error) { // Apply default connect timeout when not specified - if connectTimeout <= 0 { - connectTimeout = defaultConnectTimeout - } + connectTimeout = normalizeConnectTimeout(connectTimeout) logger.LogInfo("backend", "Creating HTTP MCP connection with transport fallback, url=%s, connectTimeout=%v", url, connectTimeout) ctx, cancel := context.WithCancel(ctx) @@ -383,10 +390,7 @@ func (c *Connection) reconnectSDKTransport() error { return fmt.Errorf("cannot reconnect: unsupported transport type %s", c.httpTransportType) } - timeout := c.connectTimeout - if timeout <= 0 { - timeout = defaultConnectTimeout - } + timeout := normalizeConnectTimeout(c.connectTimeout) connectCtx, cancel := context.WithTimeout(c.ctx, timeout) defer cancel() @@ -468,29 +472,15 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, // For plain JSON-RPC transport, use manual HTTP requests if c.httpTransportType == HTTPTransportPlainJSON { result, err = c.sendHTTPRequest(ctx, method, params) - // Log the response from backend server - var responsePayload []byte - if result != nil { - responsePayload, _ = json.Marshal(result) - } - logInboundRPCResponse(serverID, responsePayload, err, shouldAttachAgentTags, snapshot) - return result, err - } - - // For streamable and SSE transports, use SDK session methods - result, err = c.callSDKMethodWithReconnect(method, params) - // Log the response from backend server - var responsePayload []byte - if result != nil { - responsePayload, _ = json.Marshal(result) + } else { + // For streamable and SSE transports, use SDK session methods + result, err = c.callSDKMethodWithReconnect(method, params) } - logInboundRPCResponse(serverID, responsePayload, err, shouldAttachAgentTags, snapshot) - return result, err + } else { + // Handle stdio connections using SDK client + result, err = c.callSDKMethod(method, params) } - // Handle stdio connections using SDK client - result, err = c.callSDKMethod(method, params) - // Log the response from backend server var responsePayload []byte if result != nil {