From 0c0ac794a443626d3a1517f8154f109f06a26e68 Mon Sep 17 00:00:00 2001 From: David Ahmann Date: Sun, 15 Mar 2026 08:30:42 -0400 Subject: [PATCH] http: classify auth and scope failures (#2213) --- docs/error-handling.md | 7 +++ pkg/errors/error.go | 22 +++++++ pkg/errors/error_test.go | 20 ++++++ pkg/http/middleware/auth_error.go | 23 +++++++ pkg/http/middleware/scope_challenge.go | 9 ++- pkg/http/middleware/scope_challenge_test.go | 68 +++++++++++++++++++++ pkg/http/middleware/token.go | 14 +++-- pkg/http/middleware/token_test.go | 18 ++++++ 8 files changed, 175 insertions(+), 6 deletions(-) create mode 100644 pkg/http/middleware/auth_error.go create mode 100644 pkg/http/middleware/scope_challenge_test.go diff --git a/docs/error-handling.md b/docs/error-handling.md index 9bb27e0fa..6ad7d15a0 100644 --- a/docs/error-handling.md +++ b/docs/error-handling.md @@ -20,11 +20,18 @@ Used for REST API errors from the GitHub API: ```go type GitHubAPIError struct { Message string `json:"message"` + Code string `json:"code"` Response *github.Response `json:"-"` Err error `json:"-"` } ``` +For HTTP-auth related failures, `Code` is populated with a machine-readable classifier so callers and middleware can distinguish: + +- `missing_token`: no bearer token was provided +- `invalid_token`: the token was malformed or GitHub rejected it with `401` +- `insufficient_scope`: the token was valid but lacked the required permission or scope + ### GitHubGraphQLError Used for GraphQL API errors from the GitHub API: diff --git a/pkg/errors/error.go b/pkg/errors/error.go index d75765159..85eed218f 100644 --- a/pkg/errors/error.go +++ b/pkg/errors/error.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v82/github" @@ -12,6 +13,7 @@ import ( type GitHubAPIError struct { Message string `json:"message"` + Code string `json:"code,omitempty"` Response *github.Response `json:"-"` Err error `json:"-"` } @@ -20,6 +22,7 @@ type GitHubAPIError struct { func newGitHubAPIError(message string, resp *github.Response, err error) *GitHubAPIError { return &GitHubAPIError{ Message: message, + Code: classifyHTTPErrorCode(resp.Response, message), Response: resp, Err: err, } @@ -47,6 +50,7 @@ func (e *GitHubGraphQLError) Error() string { type GitHubRawAPIError struct { Message string `json:"message"` + Code string `json:"code,omitempty"` Response *http.Response `json:"-"` Err error `json:"-"` } @@ -54,11 +58,29 @@ type GitHubRawAPIError struct { func newGitHubRawAPIError(message string, resp *http.Response, err error) *GitHubRawAPIError { return &GitHubRawAPIError{ Message: message, + Code: classifyHTTPErrorCode(resp, message), Response: resp, Err: err, } } +func classifyHTTPErrorCode(resp *http.Response, message string) string { + if resp == nil { + return "" + } + + switch resp.StatusCode { + case http.StatusUnauthorized: + return "invalid_token" + case http.StatusForbidden: + if strings.Contains(strings.ToLower(message), "scope") || strings.Contains(strings.ToLower(message), "permission") { + return "insufficient_scope" + } + } + + return "" +} + func (e *GitHubRawAPIError) Error() string { return fmt.Errorf("%s: %w", e.Message, e.Err).Error() } diff --git a/pkg/errors/error_test.go b/pkg/errors/error_test.go index e33d5bd39..edfb1ab61 100644 --- a/pkg/errors/error_test.go +++ b/pkg/errors/error_test.go @@ -36,6 +36,7 @@ func TestGitHubErrorContext(t *testing.T) { apiError := apiErrors[0] assert.Equal(t, "failed to fetch resource", apiError.Message) + assert.Empty(t, apiError.Code) assert.Equal(t, resp, apiError.Response) assert.Equal(t, originalErr, apiError.Err) assert.Equal(t, "failed to fetch resource: resource not found", apiError.Error()) @@ -86,6 +87,7 @@ func TestGitHubErrorContext(t *testing.T) { rawError := rawErrors[0] assert.Equal(t, "failed to fetch raw content", rawError.Message) + assert.Empty(t, rawError.Code) assert.Equal(t, resp, rawError.Response) assert.Equal(t, originalErr, rawError.Err) }) @@ -361,6 +363,24 @@ func TestGitHubErrorContext(t *testing.T) { assert.NoError(t, err, "NewGitHubAPIErrorToCtx should handle nil context gracefully") assert.Nil(t, updatedCtx, "Context should remain nil when passed as nil") }) + + t.Run("API errors classify invalid token and insufficient scope codes from HTTP status", func(t *testing.T) { + ctx := ContextWithGitHubErrors(context.Background()) + + unauthorized := &github.Response{Response: &http.Response{StatusCode: http.StatusUnauthorized}} + forbidden := &github.Response{Response: &http.Response{StatusCode: http.StatusForbidden}} + + _, err := NewGitHubAPIErrorToCtx(ctx, "token rejected", unauthorized, fmt.Errorf("unauthorized")) + require.NoError(t, err) + _, err = NewGitHubAPIErrorToCtx(ctx, "insufficient permissions", forbidden, fmt.Errorf("forbidden")) + require.NoError(t, err) + + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + require.Len(t, apiErrors, 2) + assert.Equal(t, "invalid_token", apiErrors[0].Code) + assert.Equal(t, "insufficient_scope", apiErrors[1].Code) + }) } func TestGitHubErrorTypes(t *testing.T) { diff --git a/pkg/http/middleware/auth_error.go b/pkg/http/middleware/auth_error.go new file mode 100644 index 000000000..9c9b6f8aa --- /dev/null +++ b/pkg/http/middleware/auth_error.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "encoding/json" + "net/http" +) + +type authErrorBody struct { + Error string `json:"error"` + Code string `json:"code"` +} + +func writeAuthError(w http.ResponseWriter, status int, code string, message string, wwwAuthenticate string) { + if wwwAuthenticate != "" { + w.Header().Set("WWW-Authenticate", wwwAuthenticate) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(authErrorBody{ + Error: message, + Code: code, + }) +} diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index 1a86bf93c..ee4ca8746 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -137,8 +137,13 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter ) // Send scope challenge response with the superset of existing and required scopes - w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader) - http.Error(w, "Forbidden: insufficient scopes", http.StatusForbidden) + writeAuthError( + w, + http.StatusForbidden, + "insufficient_scope", + "Forbidden: insufficient scopes", + wwwAuthenticateHeader, + ) } return http.HandlerFunc(fn) } diff --git a/pkg/http/middleware/scope_challenge_test.go b/pkg/http/middleware/scope_challenge_test.go new file mode 100644 index 000000000..5c5352f1f --- /dev/null +++ b/pkg/http/middleware/scope_challenge_test.go @@ -0,0 +1,68 @@ +package middleware + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/oauth" + "github.com/github/github-mcp-server/pkg/scopes" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type scopeChallengeFetcher struct { + scopes []string + err error +} + +func (m *scopeChallengeFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) { + return m.scopes, m.err +} + +func TestWithScopeChallenge_ReturnsMachineReadableInsufficientScopeCode(t *testing.T) { + scopes.SetGlobalToolScopeMap(scopes.ToolScopeMap{ + "create_or_update_file": { + RequiredScopes: []string{"repo"}, + AcceptedScopes: []string{"repo"}, + }, + }) + t.Cleanup(func() { + scopes.SetGlobalToolScopeMap(nil) + }) + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := WithScopeChallenge( + &oauth.Config{BaseURL: "https://example.com"}, + &scopeChallengeFetcher{scopes: []string{}}, + )(nextHandler) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req = req.WithContext(ghcontext.WithTokenInfo(req.Context(), &ghcontext.TokenInfo{ + Token: "gho_test", + TokenType: utils.TokenTypeOAuthAccessToken, + })) + req = req.WithContext(ghcontext.WithMCPMethodInfo(req.Context(), &ghcontext.MCPMethodInfo{ + Method: "tools/call", + ItemName: "create_or_update_file", + })) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Header().Get("WWW-Authenticate"), `error="insufficient_scope"`) + + var body struct { + Code string `json:"code"` + } + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + assert.Equal(t, "insufficient_scope", body.Code) +} diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index 012bbabef..ff04d8819 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -30,8 +30,9 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl sendAuthChallenge(w, r, oauthCfg) return } - // For other auth errors (bad format, unsupported), return 400 - http.Error(w, err.Error(), http.StatusBadRequest) + // For other auth errors (bad format, unsupported), keep the existing 400 + // but expose a machine-readable invalid_token classification. + writeAuthError(w, http.StatusBadRequest, "invalid_token", err.Error(), "") return } @@ -51,6 +52,11 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) { resourcePath := oauth.ResolveResourcePath(r, oauthCfg) resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath) - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL)) - http.Error(w, "Unauthorized", http.StatusUnauthorized) + writeAuthError( + w, + http.StatusUnauthorized, + "missing_token", + "Unauthorized", + fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL), + ) } diff --git a/pkg/http/middleware/token_test.go b/pkg/http/middleware/token_test.go index fa8f0ee98..d63a7f4d3 100644 --- a/pkg/http/middleware/token_test.go +++ b/pkg/http/middleware/token_test.go @@ -1,6 +1,7 @@ package middleware import ( + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -23,6 +24,7 @@ func TestExtractUserToken(t *testing.T) { name string authHeader string expectedStatusCode int + expectedCode string expectedTokenType utils.TokenType expectedToken string expectTokenInfo bool @@ -33,6 +35,7 @@ func TestExtractUserToken(t *testing.T) { name: "missing Authorization header returns 401 with WWW-Authenticate", authHeader: "", expectedStatusCode: http.StatusUnauthorized, + expectedCode: "missing_token", expectTokenInfo: false, expectWWWAuth: true, }, @@ -151,18 +154,21 @@ func TestExtractUserToken(t *testing.T) { name: "unsupported GitHub-Bearer header returns 400", authHeader: "GitHub-Bearer some_encrypted_token", expectedStatusCode: http.StatusBadRequest, + expectedCode: "invalid_token", expectTokenInfo: false, }, { name: "invalid token format returns 400", authHeader: "Bearer invalid_token_format", expectedStatusCode: http.StatusBadRequest, + expectedCode: "invalid_token", expectTokenInfo: false, }, { name: "unrecognized prefix returns 400", authHeader: "Bearer xyz_notavalidprefix", expectedStatusCode: http.StatusBadRequest, + expectedCode: "invalid_token", expectTokenInfo: false, }, } @@ -189,6 +195,13 @@ func TestExtractUserToken(t *testing.T) { handler.ServeHTTP(rr, req) assert.Equal(t, tt.expectedStatusCode, rr.Code) + if tt.expectedCode != "" { + var body struct { + Code string `json:"code"` + } + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + assert.Equal(t, tt.expectedCode, body.Code) + } if tt.expectWWWAuth { wwwAuth := rr.Header().Get("WWW-Authenticate") @@ -253,6 +266,11 @@ func TestExtractUserToken_MissingAuthHeader_WWWAuthenticateFormat(t *testing.T) handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusUnauthorized, rr.Code) + var body struct { + Code string `json:"code"` + } + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + assert.Equal(t, "missing_token", body.Code) wwwAuth := rr.Header().Get("WWW-Authenticate") assert.NotEmpty(t, wwwAuth) assert.Contains(t, wwwAuth, "Bearer")