Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions internal/cmd/flags_difc.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ func getDefaultAllowOnlyMinIntegrity() string {
return envutil.GetEnvString("MCP_GATEWAY_ALLOWONLY_MIN_INTEGRITY", defaultAllowOnlyMinIntegrity)
}

// ValidateDIFCMode validates the guards mode flag value and returns an error if invalid
func ValidateDIFCMode(mode string) error {
_, err := difc.ParseEnforcementMode(mode)
return err
}

func parseDIFCSinkServerIDs(input string) ([]string, error) {
if strings.TrimSpace(input) == "" {
return nil, nil
Expand Down
13 changes: 8 additions & 5 deletions internal/cmd/flags_difc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestValidateDIFCMode(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateDIFCMode(tt.mode)
_, err := difc.ParseEnforcementMode(tt.mode)
if tt.wantErr {
assert.Error(t, err, "expected error for mode %q", tt.mode)
} else {
Comment on lines 61 to 66
Copy link

Copilot AI Mar 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test now validates difc.ParseEnforcementMode directly, but the test function name still references the removed ValidateDIFCMode helper. Renaming the test (and any related descriptions) will keep intent clear and avoid confusion when searching for ValidateDIFCMode usages.

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -135,10 +135,13 @@ func TestGetDefaultDIFCMode(t *testing.T) {
func TestValidDIFCModes(t *testing.T) {
require := require.New(t)

// Verify all expected modes are valid using ValidateDIFCMode
require.NoError(ValidateDIFCMode(difc.ModeStrict), "strict should be valid")
require.NoError(ValidateDIFCMode(difc.ModeFilter), "filter should be valid")
require.NoError(ValidateDIFCMode(difc.ModePropagate), "propagate should be valid")
// Verify all expected modes are valid using difc.ParseEnforcementMode
_, err := difc.ParseEnforcementMode(difc.ModeStrict)
require.NoError(err, "strict should be valid")
_, err = difc.ParseEnforcementMode(difc.ModeFilter)
require.NoError(err, "filter should be valid")
_, err = difc.ParseEnforcementMode(difc.ModePropagate)
require.NoError(err, "propagate should be valid")

// Verify ValidModes slice has 3 entries
require.Len(difc.ValidModes, 3, "should only have 3 valid modes")
Expand Down
3 changes: 2 additions & 1 deletion internal/cmd/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"path/filepath"
"syscall"

"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/proxy"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -122,7 +123,7 @@ func runProxy(cmd *cobra.Command, args []string) error {

logProxyCmd.Printf("Starting proxy: listen=%s, guard=%s, mode=%s, tls=%v", proxyListen, proxyGuardWasm, proxyDIFCMode, proxyTLS)

if err := ValidateDIFCMode(proxyDIFCMode); err != nil {
if _, err := difc.ParseEnforcementMode(proxyDIFCMode); err != nil {
return fmt.Errorf("invalid --guards-mode flag: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func run(cmd *cobra.Command, args []string) error {
}

// Validate guards mode before applying
if err := ValidateDIFCMode(difcMode); err != nil {
if _, err := difc.ParseEnforcementMode(difcMode); err != nil {
return fmt.Errorf("invalid --guards-mode flag: %w", err)
}

Expand Down
16 changes: 16 additions & 0 deletions internal/httputil/httputil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Package httputil provides shared HTTP helper utilities used across multiple
// HTTP-facing packages (server, proxy, etc.).
package httputil

import (
"encoding/json"
"net/http"
)

// WriteJSONResponse sets the Content-Type header, writes the status code, and encodes
// body as JSON. It centralises the three-line pattern used across HTTP handlers.
func WriteJSONResponse(w http.ResponseWriter, statusCode int, body interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(body)
}
6 changes: 1 addition & 5 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ type AgentTagsSnapshot struct {
Integrity []string
}

func getAgentTagsSnapshotFromContext(ctx context.Context) (*AgentTagsSnapshot, bool) {
return GetAgentTagsSnapshotFromContext(ctx)
}

// GetAgentTagsSnapshotFromContext extracts the agent DIFC tag snapshot from the request context.
// Used by guards (e.g., write-sink) that need the agent's current labels to mirror onto resources.
func GetAgentTagsSnapshotFromContext(ctx context.Context) (*AgentTagsSnapshot, bool) {
Expand Down Expand Up @@ -371,7 +367,7 @@ func (c *Connection) SendRequest(method string, params interface{}) (*Response,
// SendRequestWithServerID sends a JSON-RPC request with server ID for logging
// The ctx parameter is used to extract session ID for HTTP MCP servers
func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, params interface{}, serverID string) (*Response, error) {
snapshot, hasSnapshot := getAgentTagsSnapshotFromContext(ctx)
snapshot, hasSnapshot := GetAgentTagsSnapshotFromContext(ctx)
shouldAttachAgentTags := hasSnapshot && difc.IsSinkServerID(serverID)

// Log the outbound request to backend server
Expand Down
17 changes: 5 additions & 12 deletions internal/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/github/gh-aw-mcpg/internal/difc"
"github.com/github/gh-aw-mcpg/internal/guard"
"github.com/github/gh-aw-mcpg/internal/httputil"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/strutil"
)
Expand All @@ -35,7 +36,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Health check endpoint
if rawPath == "/health" || rawPath == "/healthz" {
writeJSONResponse(w, http.StatusOK, map[string]string{"status": "ok"})
httputil.WriteJSONResponse(w, http.StatusOK, map[string]string{"status": "ok"})
return
}

Expand Down Expand Up @@ -67,7 +68,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if match == nil {
// Unknown GraphQL query — fail closed: deny rather than risk leaking unfiltered data
logHandler.Printf("unknown GraphQL query, blocking request: %s", strutil.Truncate(string(graphQLBody), 500))
writeJSONResponse(w, http.StatusForbidden, map[string]interface{}{
httputil.WriteJSONResponse(w, http.StatusForbidden, map[string]interface{}{
"errors": []map[string]string{{"message": "access denied: unrecognized GraphQL operation"}},
"data": nil,
})
Expand Down Expand Up @@ -151,7 +152,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
} else {
// Write blocked
logHandler.Printf("[DIFC] Phase 2: BLOCKED %s %s — %s", r.Method, path, evalResult.Reason)
writeJSONResponse(w, http.StatusForbidden, map[string]string{
httputil.WriteJSONResponse(w, http.StatusForbidden, map[string]string{
"message": fmt.Sprintf("DIFC policy violation: %s", evalResult.Reason),
})
return
Expand Down Expand Up @@ -224,7 +225,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
// Strict mode: block entire response if any item filtered
if s.enforcementMode == difc.EnforcementStrict && filtered.GetFilteredCount() > 0 {
logHandler.Printf("[DIFC] STRICT: blocking response — %d filtered items", filtered.GetFilteredCount())
writeJSONResponse(w, http.StatusForbidden, map[string]string{
httputil.WriteJSONResponse(w, http.StatusForbidden, map[string]string{
"message": fmt.Sprintf("DIFC policy violation: %d of %d items not accessible",
filtered.GetFilteredCount(), filtered.TotalCount),
})
Expand Down Expand Up @@ -351,14 +352,6 @@ func (h *proxyHandler) writeEmptyResponse(w http.ResponseWriter, resp *http.Resp
w.Write([]byte(empty))
}

// writeJSONResponse sets the Content-Type header, writes the status code, and encodes
// body as JSON. It centralises the three-line pattern used across HTTP handlers.
func writeJSONResponse(w http.ResponseWriter, statusCode int, body interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(body)
}

// forwardAndReadBody forwards a request to the upstream GitHub API and reads the
// entire response body. On success it returns the response and body bytes. It writes
// a 502 error to w and returns nil, nil on failure.
Expand Down
6 changes: 2 additions & 4 deletions internal/server/http_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package server
import (
"bytes"
"context"
"encoding/json"
"io"
"log"
"net/http"

"github.com/github/gh-aw-mcpg/internal/auth"
"github.com/github/gh-aw-mcpg/internal/guard"
"github.com/github/gh-aw-mcpg/internal/httputil"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/logger/sanitize"
"github.com/github/gh-aw-mcpg/internal/mcp"
Expand All @@ -20,9 +20,7 @@ var logHelpers = logger.New("server:helpers")
// writeJSONResponse sets the Content-Type header, writes the status code, and encodes
// body as JSON. It centralises the three-line pattern used across HTTP handlers.
func writeJSONResponse(w http.ResponseWriter, statusCode int, body interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(body)
httputil.WriteJSONResponse(w, statusCode, body)
}

// withResponseLogging wraps an http.Handler to log response bodies
Expand Down
Loading