diff --git a/docs/PROXY_MODE.md b/docs/PROXY_MODE.md index 1a56a2fa..ab04a9c2 100644 --- a/docs/PROXY_MODE.md +++ b/docs/PROXY_MODE.md @@ -89,14 +89,19 @@ The proxy reuses the same 6-phase pipeline as the MCP gateway, with Phase 3 adap ## REST Route Mapping -The proxy maps ~25 GitHub REST API URL patterns to guard tool names: +The proxy maps REST API URL patterns to guard tool names (see `internal/proxy/router.go` for the exact source of truth). Inbound paths are normalized first: + +- `GH_HOST` style REST paths with `/api/v3/...` are normalized to `/...` for routing. +- Query strings are ignored for route matching and still forwarded upstream. + +Supported path families include: | URL Pattern | Guard Tool | |-------------|-----------| | `/repos/:owner/:repo/issues` | `list_issues` | -| `/repos/:owner/:repo/issues/:number` | `get_issue` | +| `/repos/:owner/:repo/issues/:number` | `issue_read` | | `/repos/:owner/:repo/pulls` | `list_pull_requests` | -| `/repos/:owner/:repo/pulls/:number` | `get_pull_request` | +| `/repos/:owner/:repo/pulls/:number` | `pull_request_read` | | `/repos/:owner/:repo/commits` | `list_commits` | | `/repos/:owner/:repo/commits/:sha` | `get_commit` | | `/repos/:owner/:repo/contents/:path` | `get_file_contents` | @@ -106,21 +111,34 @@ The proxy maps ~25 GitHub REST API URL patterns to guard tool names: | `/search/code` | `search_code` | | `/search/repositories` | `search_repositories` | | `/user` | `get_me` | -| ... | See `internal/proxy/router.go` for full list | +| `/notifications` | `list_notifications` | +| `/orgs/:owner/actions/(secrets|variables)[/:name]` | `actions_list` | +| `/repos/:owner/:repo/discussions...` | `list_discussions` / `get_discussion_comments` | +| `/repos/:owner/:repo/...` (fallback) | `get_file_contents` | +| ... | See `internal/proxy/router.go` for the complete regex list and precedence | -Unrecognized URLs pass through without DIFC filtering. +For **read operations** (GET and GraphQL POST), unmatched routes are denied (fail-closed) to avoid accidental unfiltered data exposure. For **write operations** (non-read methods), requests pass through unchanged. ## GraphQL Support -GraphQL queries to `/graphql` are parsed to extract the operation type and owner/repo context: +Inbound GraphQL endpoint paths accepted by the proxy: + +- `/graphql` (github.com style) +- `/api/graphql` (GHES style used by `gh` when host is GHES/proxy) +- `/api/v3/graphql` (GH_HOST prefix style; normalized) + +GraphQL queries are parsed to extract operation type and owner/repo context: - **Repository-scoped queries** (issues, PRs, commits) — mapped to corresponding tool names - **Search queries** — mapped to `search_issues` or `search_code` - **Viewer queries** — mapped to `get_me` -- **Unknown queries** — passed through without filtering +- **Schema introspection (`__schema`, `__type`)** — passed through (safe metadata) +- **Unknown queries** — denied (fail-closed) Owner and repo are extracted from GraphQL variables (`$owner`, `$name`/`$repo`) or inline string arguments. +When the upstream API base is GHES-style `.../api/v3`, GraphQL forwarding is rewritten to `.../api/graphql` to match GHES routing. + ## Policy Notes - **Repo names must be lowercase** in policies (e.g., `octocat/hello-world` not `octocat/Hello-World`). The guard performs case-insensitive matching against actual GitHub data. diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 956c9b2c..94bc3678 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -101,7 +101,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if match.ToolName == "graphql_introspection" { logHandler.Printf("GraphQL introspection query, passing through") clientAuth := r.Header.Get("Authorization") - resp, respBody := h.forwardAndReadBody(w, r.Context(), http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth) + resp, respBody := h.forwardAndReadBody(w, r.Context(), http.MethodPost, fullPath, bytes.NewReader(graphQLBody), "application/json", clientAuth) if resp == nil { return } @@ -206,7 +206,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa oteltrace.WithSpanKind(oteltrace.SpanKindClient), ) if graphQLBody != nil { - resp, respBody = h.forwardAndReadBody(w, fwdCtx, http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth) + resp, respBody = h.forwardAndReadBody(w, fwdCtx, http.MethodPost, path, bytes.NewReader(graphQLBody), "application/json", clientAuth) } else { resp, respBody = h.forwardAndReadBody(w, fwdCtx, r.Method, path, nil, "", clientAuth) } diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index 76b3d219..ce43e420 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -188,6 +188,47 @@ func TestServeHTTP_GraphQLIntrospectionPassthrough(t *testing.T) { assert.Contains(t, w.Body.String(), "__schema") } +func TestServeHTTP_GraphQLPreservesQueryString(t *testing.T) { + tests := []struct { + name string + path string + wantPath string + }{ + {name: "graphql path", path: "/graphql?foo=bar", wantPath: "/graphql?foo=bar"}, + {name: "ghes api graphql path", path: "/api/graphql?foo=bar", wantPath: "/api/graphql?foo=bar"}, + {name: "gh host prefixed graphql path", path: "/api/v3/graphql?foo=bar", wantPath: "/graphql?foo=bar"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var receivedURL string + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedURL = r.URL.RequestURI() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{"data":{"repository":{"issues":{"nodes":[]}}}}`)) + require.NoError(t, err) + })) + defer upstream.Close() + + s := newTestServer(t, upstream.URL) + h := &proxyHandler{server: s} + + gqlBody, err := json.Marshal(map[string]interface{}{ + "query": `{ repository(owner:"org", name:"repo") { issues(first: 10) { nodes { id } } } }`, + }) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, tt.path, bytes.NewReader(gqlBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, tt.wantPath, receivedURL) + }) + } +} + // ─── ServeHTTP: query string is forwarded on REST GET ──────────────────────── func TestServeHTTP_QueryStringForwardedToUpstream(t *testing.T) { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 3b8cb41b..130ff49d 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -343,6 +343,13 @@ func (r *restBackendCaller) CallTool(ctx context.Context, toolName string, args // if non-empty it is forwarded as-is, otherwise the configured fallback token is used. func (s *Server) forwardToGitHub(ctx context.Context, method, path string, body io.Reader, contentType string, clientAuth string) (*http.Response, error) { url := s.githubAPIURL + path + pathOnly, query, hasQuery := strings.Cut(path, "?") + if strings.HasSuffix(s.githubAPIURL, "/api/v3") && IsGraphQLPath(pathOnly) { + url = strings.TrimSuffix(s.githubAPIURL, "/api/v3") + "/api/graphql" + if hasQuery { + url += "?" + query + } + } logProxy.Printf("forwarding %s %s → %s", method, path, url) req, err := http.NewRequestWithContext(ctx, method, url, body) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 1c69dc0c..f929d2ea 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "io" "net/http" "net/http/httptest" @@ -844,3 +845,45 @@ func TestUnwrapSingleObject(t *testing.T) { }) } } +func TestForwardToGitHub_RewritesGraphQLPathForGHESAPIBase(t *testing.T) { + var receivedPath string + var receivedQuery string + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedPath = r.URL.Path + receivedQuery = r.URL.RawQuery + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + s := &Server{ + githubAPIURL: upstream.URL + "/api/v3", + httpClient: upstream.Client(), + } + + tests := []struct { + name string + path string + wantPath string + wantQuery string + }{ + {name: "plain graphql endpoint", path: "/graphql", wantPath: "/api/graphql"}, + {name: "graphql endpoint with query string", path: "/graphql?foo=bar", wantPath: "/api/graphql", wantQuery: "foo=bar"}, + {name: "ghes api graphql endpoint", path: "/api/graphql", wantPath: "/api/graphql"}, + {name: "gh host prefixed graphql endpoint with query string", path: "/api/v3/graphql?foo=bar", wantPath: "/api/graphql", wantQuery: "foo=bar"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + receivedPath = "" + receivedQuery = "" + + resp, err := s.forwardToGitHub(context.Background(), http.MethodPost, tt.path, nil, "application/json", "") + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + assert.Equal(t, tt.wantPath, receivedPath) + assert.Equal(t, tt.wantQuery, receivedQuery) + }) + } +}