Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pkg/tools/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,16 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
return nil, fmt.Errorf("failed to parse tool arguments: %w", err)
}

// Strip null values from arguments. Some models (e.g. OpenAI) send explicit
// null for optional parameters, but MCP servers may reject them because
// null is not a valid value for the declared parameter type (e.g. string).
// Omitting the key is semantically equivalent to null for optional params.
for k, v := range args {
if v == nil {
delete(args, k)
}
}

request := &mcp.CallToolParams{}
request.Name = toolCall.Function.Name
request.Arguments = args
Expand Down
105 changes: 105 additions & 0 deletions pkg/tools/mcp/mcp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package mcp

import (
"context"
"iter"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/tools"
)

// mockMCPClient is a test double for the mcpClient interface.
type mockMCPClient struct {
callToolFn func(ctx context.Context, request *mcp.CallToolParams) (*mcp.CallToolResult, error)
}

func (m *mockMCPClient) Initialize(context.Context, *mcp.InitializeRequest) (*mcp.InitializeResult, error) {
return &mcp.InitializeResult{}, nil
}

func (m *mockMCPClient) ListTools(context.Context, *mcp.ListToolsParams) iter.Seq2[*mcp.Tool, error] {
return func(func(*mcp.Tool, error) bool) {}
}

func (m *mockMCPClient) CallTool(ctx context.Context, request *mcp.CallToolParams) (*mcp.CallToolResult, error) {
return m.callToolFn(ctx, request)
}

func (m *mockMCPClient) ListPrompts(context.Context, *mcp.ListPromptsParams) iter.Seq2[*mcp.Prompt, error] {
return func(func(*mcp.Prompt, error) bool) {}
}

func (m *mockMCPClient) GetPrompt(context.Context, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {
return &mcp.GetPromptResult{}, nil
}

func (m *mockMCPClient) SetElicitationHandler(tools.ElicitationHandler) {}

func (m *mockMCPClient) SetOAuthSuccessHandler(func()) {}

func (m *mockMCPClient) SetManagedOAuth(bool) {}

func (m *mockMCPClient) Close(context.Context) error { return nil }

func TestCallToolStripsNullArguments(t *testing.T) {
t.Parallel()

tests := []struct {
name string
arguments string
expectedArgs map[string]any
}{
{
name: "all null values are stripped",
arguments: `{"dir": null, "pattern": null}`,
expectedArgs: map[string]any{},
},
{
name: "only null values are stripped",
arguments: `{"dir": ".", "pattern": null, "extra": "value"}`,
expectedArgs: map[string]any{"dir": ".", "extra": "value"},
},
{
name: "empty arguments stay empty",
arguments: `{}`,
expectedArgs: map[string]any{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var capturedArgs map[string]any

ts := &Toolset{
started: true,
mcpClient: &mockMCPClient{
callToolFn: func(_ context.Context, request *mcp.CallToolParams) (*mcp.CallToolResult, error) {
if m, ok := request.Arguments.(map[string]any); ok {
capturedArgs = m
}
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{Text: "ok"}},
}, nil
},
},
}

result, err := ts.callTool(t.Context(), tools.ToolCall{
Function: tools.FunctionCall{
Name: "test_tool",
Arguments: tt.arguments,
},
})

require.NoError(t, err)
assert.Equal(t, "ok", result.Output)
assert.Equal(t, tt.expectedArgs, capturedArgs)
})
}
}