Skip to content

Commit c8a9dce

Browse files
authored
feat(mcp): add tool registry to handle tools from different endpoints (#1237)
* add tool registry registry tracks tools and from which endpoints they originated from * unexport DoToolCall and DecodeToolResponse * use tool registry in mcp command for tool execution * fix MCP casing
1 parent 97c3004 commit c8a9dce

File tree

3 files changed

+106
-35
lines changed

3 files changed

+106
-35
lines changed

cmd/src/mcp.go

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88
"strings"
99

10-
"github.com/sourcegraph/src-cli/internal/api"
1110
"github.com/sourcegraph/src-cli/internal/mcp"
1211

1312
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -36,8 +35,8 @@ func mcpMain(args []string) error {
3635
apiClient := cfg.apiClient(nil, mcpFlagSet.Output())
3736

3837
ctx := context.Background()
39-
tools, err := mcp.FetchToolDefinitions(ctx, apiClient)
40-
if err != nil {
38+
registry := mcp.NewToolRegistry()
39+
if err := registry.LoadTools(ctx, apiClient); err != nil {
4140
return err
4241
}
4342

@@ -49,7 +48,7 @@ func mcpMain(args []string) error {
4948
subcmd := args[0]
5049
if subcmd == "list-tools" {
5150
fmt.Println("The following tools are available:")
52-
for name := range tools {
51+
for name := range registry.All() {
5352
fmt.Printf(" %s\n", name)
5453
}
5554
fmt.Println("\nUSAGE:")
@@ -58,7 +57,7 @@ func mcpMain(args []string) error {
5857
fmt.Println(" src mcp <tool-name> -h List the available flags of a tool")
5958
return nil
6059
}
61-
tool, ok := tools[subcmd]
60+
tool, ok := registry.Get(subcmd)
6261
if !ok {
6362
return errors.Newf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd)
6463
}
@@ -81,7 +80,17 @@ func mcpMain(args []string) error {
8180
return err
8281
}
8382

84-
return handleMcpTool(context.Background(), apiClient, tool, vars)
83+
result, err := registry.CallTool(ctx, apiClient, tool.Name, vars)
84+
if err != nil {
85+
return err
86+
}
87+
88+
output, err := json.MarshalIndent(result, "", " ")
89+
if err != nil {
90+
return err
91+
}
92+
fmt.Println(string(output))
93+
return nil
8594
}
8695

8796
func printSchemas(tool *mcp.ToolDef) error {
@@ -111,23 +120,3 @@ func validateToolArgs(inputSchema mcp.SchemaObject, args []string, vars map[stri
111120

112121
return nil
113122
}
114-
115-
func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.ToolDef, vars map[string]any) error {
116-
resp, err := mcp.DoToolCall(ctx, client, tool.RawName, vars)
117-
if err != nil {
118-
return err
119-
}
120-
121-
result, err := mcp.DecodeToolResponse(resp)
122-
if err != nil {
123-
return err
124-
}
125-
defer resp.Body.Close()
126-
127-
output, err := json.MarshalIndent(result, "", " ")
128-
if err != nil {
129-
return err
130-
}
131-
fmt.Println(string(output))
132-
return nil
133-
}

internal/mcp/mcp_request.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ import (
1313
"github.com/sourcegraph/sourcegraph/lib/errors"
1414
)
1515

16-
const McpURLPath = ".api/mcp/v1"
16+
const MCPURLPath = ".api/mcp/v1"
17+
const MCPDeepSearchURLPath = ".api/mcp/deepsearch"
1718

18-
func FetchToolDefinitions(ctx context.Context, client api.Client) (map[string]*ToolDef, error) {
19-
resp, err := doJSONRPC(ctx, client, "tools/list", nil)
19+
func fetchToolDefinitions(ctx context.Context, client api.Client, endpoint string) (map[string]*ToolDef, error) {
20+
resp, err := doJSONRPC(ctx, client, endpoint, "tools/list", nil)
2021
if err != nil {
2122
return nil, errors.Wrap(err, "failed to list tools from mcp endpoint")
2223
}
@@ -44,7 +45,7 @@ func FetchToolDefinitions(ctx context.Context, client api.Client) (map[string]*T
4445
return loadToolDefinitions(rpcResp.Result)
4546
}
4647

47-
func DoToolCall(ctx context.Context, client api.Client, tool string, vars map[string]any) (*http.Response, error) {
48+
func doToolCall(ctx context.Context, client api.Client, endpoint string, tool string, vars map[string]any) (*http.Response, error) {
4849
params := struct {
4950
Name string `json:"name"`
5051
Arguments map[string]any `json:"arguments"`
@@ -53,10 +54,10 @@ func DoToolCall(ctx context.Context, client api.Client, tool string, vars map[st
5354
Arguments: vars,
5455
}
5556

56-
return doJSONRPC(ctx, client, "tools/call", params)
57+
return doJSONRPC(ctx, client, endpoint, "tools/call", params)
5758
}
5859

59-
func doJSONRPC(ctx context.Context, client api.Client, method string, params any) (*http.Response, error) {
60+
func doJSONRPC(ctx context.Context, client api.Client, endpoint string, method string, params any) (*http.Response, error) {
6061
jsonRPC := struct {
6162
Version string `json:"jsonrpc"`
6263
ID int `json:"id"`
@@ -75,7 +76,7 @@ func doJSONRPC(ctx context.Context, client api.Client, method string, params any
7576
}
7677
buf.Write(data)
7778

78-
req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpURLPath, buf)
79+
req, err := client.NewHTTPRequest(ctx, http.MethodPost, endpoint, buf)
7980
if err != nil {
8081
return nil, err
8182
}
@@ -91,13 +92,13 @@ func doJSONRPC(ctx context.Context, client api.Client, method string, params any
9192
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
9293
resp.Body.Close()
9394
return nil, errors.Newf("MCP endpoint %s returned %d: %s",
94-
McpURLPath, resp.StatusCode, strings.TrimSpace(string(body)))
95+
endpoint, resp.StatusCode, strings.TrimSpace(string(body)))
9596
}
9697

9798
return resp, nil
9899
}
99100

100-
func DecodeToolResponse(resp *http.Response) (map[string]json.RawMessage, error) {
101+
func decodeToolResponse(resp *http.Response) (map[string]json.RawMessage, error) {
101102
data, err := readSSEResponseData(resp)
102103
if err != nil {
103104
return nil, err

internal/mcp/registry.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package mcp
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"iter"
7+
8+
"github.com/sourcegraph/src-cli/internal/api"
9+
10+
"github.com/sourcegraph/sourcegraph/lib/errors"
11+
)
12+
13+
// ToolRegistry keeps track of tools and the endpoints they originated from
14+
type ToolRegistry struct {
15+
tools map[string]*ToolDef
16+
endpoints map[string]string
17+
}
18+
19+
func NewToolRegistry() *ToolRegistry {
20+
return &ToolRegistry{
21+
tools: make(map[string]*ToolDef),
22+
endpoints: make(map[string]string),
23+
}
24+
}
25+
26+
// LoadTools loads the tool definitions from the Mcp tool endpoints constants McpURLPath and McpDeepSearchURLPath
27+
func (r *ToolRegistry) LoadTools(ctx context.Context, client api.Client) error {
28+
endpoints := []string{MCPURLPath, MCPDeepSearchURLPath}
29+
30+
var errs []error
31+
for _, endpoint := range endpoints {
32+
tools, err := fetchToolDefinitions(ctx, client, endpoint)
33+
if err != nil {
34+
errs = append(errs, errors.Wrapf(err, "failed to load tools from %s", endpoint))
35+
continue
36+
}
37+
r.register(endpoint, tools)
38+
}
39+
40+
if len(errs) > 0 {
41+
return errors.Append(nil, errs...)
42+
}
43+
return nil
44+
}
45+
46+
// register associates a collection of tools with the given endpoint
47+
func (r *ToolRegistry) register(endpoint string, tools map[string]*ToolDef) {
48+
for name, def := range tools {
49+
r.tools[name] = def
50+
r.endpoints[name] = endpoint
51+
}
52+
}
53+
54+
// Get returns the tool definition for the given name
55+
func (r *ToolRegistry) Get(name string) (*ToolDef, bool) {
56+
tool, ok := r.tools[name]
57+
return tool, ok
58+
}
59+
60+
// CallTool calls the given tool with the given arguments. It constructs the Tool request and decodes the Tool response
61+
func (r *ToolRegistry) CallTool(ctx context.Context, client api.Client, name string, args map[string]any) (map[string]json.RawMessage, error) {
62+
tool := r.tools[name]
63+
endpoint := r.endpoints[name]
64+
resp, err := doToolCall(ctx, client, endpoint, tool.RawName, args)
65+
if err != nil {
66+
return nil, err
67+
}
68+
defer resp.Body.Close()
69+
return decodeToolResponse(resp)
70+
}
71+
72+
// All returns an iterator that yields the name and Tool definition of all registered tools
73+
func (r *ToolRegistry) All() iter.Seq2[string, *ToolDef] {
74+
return func(yield func(string, *ToolDef) bool) {
75+
for name, def := range r.tools {
76+
if !yield(name, def) {
77+
return
78+
}
79+
}
80+
}
81+
}

0 commit comments

Comments
 (0)