diff --git a/docs/rough_edges.md b/docs/rough_edges.md index e0e6d730..7ee18b1a 100644 --- a/docs/rough_edges.md +++ b/docs/rough_edges.md @@ -59,3 +59,7 @@ v2. wrapper) we need to first unmarshal into a `map[string]any` in order to do server-side validation of required fields. CallToolParams could have just had a map[string]any. + +- `StreamableHTTPOptions.CrossOriginProtection` should not have been part of + the SDK API. Cross-origin protection is a general HTTP concern, not specific + to MCP, and can be applied as standard HTTP middleware. diff --git a/internal/docs/rough_edges.src.md b/internal/docs/rough_edges.src.md index 42e79f78..d573e141 100644 --- a/internal/docs/rough_edges.src.md +++ b/internal/docs/rough_edges.src.md @@ -58,3 +58,7 @@ v2. wrapper) we need to first unmarshal into a `map[string]any` in order to do server-side validation of required fields. CallToolParams could have just had a map[string]any. + +- `StreamableHTTPOptions.CrossOriginProtection` should not have been part of + the SDK API. Cross-origin protection is a general HTTP concern, not specific + to MCP, and can be applied as standard HTTP middleware. diff --git a/mcp/sse.go b/mcp/sse.go index f8f156d8..0e1ad79e 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -10,6 +10,7 @@ import ( "crypto/rand" "fmt" "io" + "mime" "net" "net/http" "net/url" @@ -64,14 +65,6 @@ type SSEOptions struct { // Only disable this if you understand the security implications. // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise DisableLocalhostProtection bool - - // CrossOriginProtection allows to customize cross-origin protection. - // The deny handler set in the CrossOriginProtection through SetDenyHandler - // is ignored. - // If nil, default (zero-value) cross-origin protection will be used. - // Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter - // to disable the default protection until v1.7.0. - CrossOriginProtection *http.CrossOriginProtection } // NewSSEHandler returns a new [SSEHandler] that creates and manages MCP @@ -97,10 +90,6 @@ func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptio s.opts = *opts } - if s.opts.CrossOriginProtection == nil { - s.opts.CrossOriginProtection = &http.CrossOriginProtection{} - } - return s } @@ -212,20 +201,13 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } - if disablecrossoriginprotection != "1" { - // Verify the 'Origin' header to protect against CSRF attacks. - if err := h.opts.CrossOriginProtection.Check(req); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) + // Validate 'Content-Type' header. + if req.Method == http.MethodPost { + mediaType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type")) + if err != nil || mediaType != "application/json" { + http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) return } - // Validate 'Content-Type' header. - if req.Method == http.MethodPost { - contentType := req.Header.Get("Content-Type") - if contentType != "application/json" { - http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) - return - } - } } sessionID := req.URL.Query().Get("sessionid") diff --git a/mcp/sse_test.go b/mcp/sse_test.go index fe230a51..86b2cf1b 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -12,7 +12,6 @@ import ( "net" "net/http" "net/http/httptest" - "strings" "sync/atomic" "testing" @@ -320,77 +319,3 @@ func TestSSELocalhostProtection(t *testing.T) { }) } } - -func TestSSEOriginProtection(t *testing.T) { - server := NewServer(testImpl, nil) - - tests := []struct { - name string - protection *http.CrossOriginProtection - requestOrigin string - wantStatusCode int - }{ - { - name: "default protection with Origin header", - protection: nil, - requestOrigin: "https://example.com", - wantStatusCode: http.StatusForbidden, - }, - { - name: "custom protection with trusted origin and same Origin", - protection: func() *http.CrossOriginProtection { - p := http.NewCrossOriginProtection() - if err := p.AddTrustedOrigin("https://example.com"); err != nil { - t.Fatal(err) - } - return p - }(), - requestOrigin: "https://example.com", - wantStatusCode: http.StatusNotFound, // origin accepted; session not found - }, - { - name: "custom protection with trusted origin and different Origin", - protection: func() *http.CrossOriginProtection { - p := http.NewCrossOriginProtection() - if err := p.AddTrustedOrigin("https://example.com"); err != nil { - t.Fatal(err) - } - return p - }(), - requestOrigin: "https://malicious.com", - wantStatusCode: http.StatusForbidden, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - opts := &SSEOptions{ - CrossOriginProtection: tt.protection, - } - handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts) - httpServer := httptest.NewServer(handler) - defer httpServer.Close() - - // Use POST with a valid session-like URL to test origin protection - // without creating a hanging GET connection. - reqReader := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping"}`) - req, err := http.NewRequest(http.MethodPost, httpServer.URL+"?sessionid=nonexistent", reqReader) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Origin", tt.requestOrigin) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if got := resp.StatusCode; got != tt.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Errorf("Status code: got %d, want %d (body: %s)", got, tt.wantStatusCode, body) - } - }) - } -} diff --git a/mcp/streamable.go b/mcp/streamable.go index 8deb6c93..1455238f 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -180,9 +180,15 @@ type StreamableHTTPOptions struct { // CrossOriginProtection allows to customize cross-origin protection. // The deny handler set in the CrossOriginProtection through SetDenyHandler // is ignored. - // If nil, default (zero-value) cross-origin protection will be used. - // Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter - // to disable the default protection until v1.7.0. + // If nil, no cross-origin protection is applied. Use the `enableoriginverification` + // MCPGODEBUG compatibility parameter to enable the default protection until v1.8.0. + // + // Deprecated: wrap the handler with cross-origin protection middleware + // instead. For example: + // + // handler := mcp.NewStreamableHTTPHandler(...) + // protection := http.NewCrossOriginProtection() + // protectedHandler := protection.Handler(handler) CrossOriginProtection *http.CrossOriginProtection } @@ -202,7 +208,7 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea h.opts.Logger = ensureLogger(h.opts.Logger) - if h.opts.CrossOriginProtection == nil { + if h.opts.CrossOriginProtection == nil && enableoriginverification == "1" { h.opts.CrossOriginProtection = &http.CrossOriginProtection{} } @@ -235,15 +241,16 @@ func (h *StreamableHTTPHandler) closeAll() { // disablelocalhostprotection is a compatibility parameter that allows to disable // DNS rebinding protection, which was added in the 1.4.0 version of the SDK. // See the documentation for the mcpgodebug package for instructions how to enable it. -// The option will be removed in the 1.7.0 version of the SDK. +// The option will be removed in the 1.6.0 version of the SDK. var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection") -// disablecrossoriginprotection is a compatibility parameter that allows to disable -// the verification of the 'Origin' and 'Content-Type' headers, which was added in -// the 1.4.1 version of the SDK. See the documentation for the mcpgodebug package -// for instructions how to enable it. -// The option will be removed in the 1.7.0 version of the SDK. -var disablecrossoriginprotection = mcpgodebug.Value("disablecrossoriginprotection") +// enableoriginverification is a compatibility parameter that restores the +// default cross-origin protection behavior from v1.4.1-v1.5.0. When set to +// "1", a zero-value CrossOriginProtection will be applied if none is +// explicitly provided in StreamableHTTPOptions. +// See the documentation for the mcpgodebug package for instructions how to enable it. +// The option will be removed in the 1.8.0 version of the SDK. +var enableoriginverification = mcpgodebug.Value("enableoriginverification") func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // DNS rebinding protection: auto-enabled for localhost servers. @@ -257,17 +264,18 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } } - if disablecrossoriginprotection != "1" { + if h.opts.CrossOriginProtection != nil { // Verify the 'Origin' header to protect against CSRF attacks. if err := h.opts.CrossOriginProtection.Check(req); err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } - // Validate 'Content-Type' header. - if req.Method == http.MethodPost && baseMediaType(req.Header.Get("Content-Type")) != "application/json" { - http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) - return - } + } + + // Validate 'Content-Type' header. + if req.Method == http.MethodPost && baseMediaType(req.Header.Get("Content-Type")) != "application/json" { + http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) + return } // Allow multiple 'Accept' headers. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 592981fc..83326218 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2587,10 +2587,10 @@ func TestStreamableOriginProtection(t *testing.T) { wantStatusCode int }{ { - name: "default protection with Origin header", + name: "no protection with Origin header", protection: nil, requestOrigin: "https://example.com", - wantStatusCode: http.StatusForbidden, + wantStatusCode: http.StatusOK, }, { name: "custom protection with trusted origin and same Origin",