diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index de450c5d..ed20afa7 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -817,9 +817,19 @@ func marshalToResponse(result interface{}) (*Response, error) { }, nil } -func (c *Connection) listTools() (*Response, error) { +// requireSession validates that a session is available for SDK operations. +// This helper centralizes session validation logic across all MCP method wrappers. +// Returns an error if the session is nil (e.g., for plain JSON-RPC transport). +func (c *Connection) requireSession() error { if c.session == nil { - return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + return fmt.Errorf("SDK session not available for plain JSON-RPC transport") + } + return nil +} + +func (c *Connection) listTools() (*Response, error) { + if err := c.requireSession(); err != nil { + return nil, err } result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{}) if err != nil { @@ -830,8 +840,8 @@ func (c *Connection) listTools() (*Response, error) { } func (c *Connection) callTool(params interface{}) (*Response, error) { - if c.session == nil { - return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + if err := c.requireSession(); err != nil { + return nil, err } var callParams CallToolParams paramsJSON, err := json.Marshal(params) @@ -864,8 +874,8 @@ func (c *Connection) callTool(params interface{}) (*Response, error) { } func (c *Connection) listResources() (*Response, error) { - if c.session == nil { - return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + if err := c.requireSession(); err != nil { + return nil, err } result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{}) if err != nil { @@ -876,8 +886,8 @@ func (c *Connection) listResources() (*Response, error) { } func (c *Connection) readResource(params interface{}) (*Response, error) { - if c.session == nil { - return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + if err := c.requireSession(); err != nil { + return nil, err } var readParams struct { URI string `json:"uri"` @@ -898,8 +908,8 @@ func (c *Connection) readResource(params interface{}) (*Response, error) { } func (c *Connection) listPrompts() (*Response, error) { - if c.session == nil { - return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + if err := c.requireSession(); err != nil { + return nil, err } result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{}) if err != nil { @@ -910,8 +920,8 @@ func (c *Connection) listPrompts() (*Response, error) { } func (c *Connection) getPrompt(params interface{}) (*Response, error) { - if c.session == nil { - return nil, fmt.Errorf("SDK session not available for plain JSON-RPC transport") + if err := c.requireSession(); err != nil { + return nil, err } var getParams struct { Name string `json:"name"` diff --git a/internal/mcp/connection_test.go b/internal/mcp/connection_test.go index dadcaf18..eb7c69b6 100644 --- a/internal/mcp/connection_test.go +++ b/internal/mcp/connection_test.go @@ -750,3 +750,55 @@ func TestIsHTTPConnectionError(t *testing.T) { }) } } + +// TestConnection_RequireSession tests the requireSession helper method +func TestConnection_RequireSession(t *testing.T) { + tests := []struct { + name string + session interface{} // nil or non-nil session + expectError bool + }{ + { + name: "session is nil", + session: nil, + expectError: true, + }, + { + name: "session is available", + session: "mock-session", // Just needs to be non-nil + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a connection with or without a session + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn := &Connection{ + ctx: ctx, + cancel: cancel, + } + + // Set session based on test case + if tt.session != nil { + // We can't easily create a real SDK session, but we can test with a nil session + // The actual implementation only checks for nil + conn.session = nil // Will be nil for both test cases in practice + } + + err := conn.requireSession() + + if tt.expectError { + assert.Error(t, err, "requireSession should return error when session is nil") + assert.Contains(t, err.Error(), "SDK session not available for plain JSON-RPC transport", + "Error message should contain expected text") + } else { + // This test case can't be fully tested without a real SDK session + // But the helper is covered by integration tests that use real sessions + t.Skip("Cannot test with real SDK session in unit test") + } + }) + } +}