diff --git a/AGENTS.md b/AGENTS.md index ed237756805..cf81d32820d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -56,7 +56,7 @@ These targets can be invoked via `make ` as needed during development an ## Pull request guidelines - PR titles must start with a category prefix describing the change: `๐Ÿ› bug:`, `๐Ÿ”ฅ feat:`, `๐Ÿ“’ docs:`, or `๐Ÿงน chore:`. -- Generated PR bodies should contain a **Summary** section that captures all changes included in the PR, not just the latest commit. +- Generated PR titles and bodies must summarize the *entire* set of changes on the branch (for example, based on `git log --oneline ..HEAD` or the full diff), **not** just the latest commit. The Summary section should reflect all modifications that will be merged. ## Programmatic checks @@ -75,3 +75,7 @@ make test ``` All checks must pass before the generated code can be merged. + +After completing the programmatic checks above, confirm that any relevant +documentation has been updated to reflect the changes made, including PR +instructions when applicable. diff --git a/constants.go b/constants.go index a93b2413180..5f93fe1cd4b 100644 --- a/constants.go +++ b/constants.go @@ -256,6 +256,7 @@ const ( HeaderTE = "TE" HeaderTrailer = "Trailer" HeaderTransferEncoding = "Transfer-Encoding" + HeaderSecFetchSite = "Sec-Fetch-Site" HeaderSecWebSocketAccept = "Sec-WebSocket-Accept" HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions" HeaderSecWebSocketKey = "Sec-WebSocket-Key" diff --git a/docs/middleware/csrf.md b/docs/middleware/csrf.md index c5d7cb39881..162f8a84eeb 100644 --- a/docs/middleware/csrf.md +++ b/docs/middleware/csrf.md @@ -171,6 +171,10 @@ async function makeRequest(url, data) { The middleware employs a robust, defense-in-depth strategy to protect against CSRF attacks. The primary defense is token-based validation, which operates in one of two modes depending on your configuration. This is supplemented by a mandatory secondary check on the request's origin. +### Fetch Metadata Guardrails + +- **Sec-Fetch-Site**: For unsafe methods, the middleware inspects the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header when present. If the header value is not one of "same-origin", "none", "same-site", or "cross-site", the request is rejected with `ErrFetchSiteInvalid`. If the header is valid or absent, the request proceeds to the standard origin and token validation checks. This provides an early check to block requests with invalid `Sec-Fetch-Site` values, while allowing legitimate same-site and cross-site requests to be validated by the existing mechanisms. + ### 1. Token Validation Patterns #### Double Submit Cookie (Default Mode) diff --git a/docs/whats_new.md b/docs/whats_new.md index 1ffaa70f82b..2a09e7f5278 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -1286,6 +1286,8 @@ The `Expiration` field in the CSRF middleware configuration has been renamed to CSRF now redacts tokens and storage keys by default and exposes a `DisableValueRedaction` toggle (default `false`) if you must surface those values in diagnostics. +The CSRF middleware now validates the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header for unsafe HTTP methods. When present, requests with invalid `Sec-Fetch-Site` values (not one of "same-origin", "none", "same-site", or "cross-site") are rejected with `ErrFetchSiteInvalid`. Valid or absent headers proceed to standard origin and token validation checks, providing an early gate to catch malformed requests while maintaining compatibility with legitimate cross-site traffic. + ### Idempotency Idempotency middleware now redacts keys by default and offers a `DisableValueRedaction` configuration flag (default `false`) to expose them when debugging. diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index 28a287bbcb2..98a4da7b8ae 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -14,15 +14,16 @@ import ( ) var ( - ErrTokenNotFound = errors.New("csrf: token not found") - ErrTokenInvalid = errors.New("csrf: token invalid") - ErrRefererNotFound = errors.New("csrf: referer header missing") - ErrRefererInvalid = errors.New("csrf: referer header invalid") - ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins") - ErrOriginInvalid = errors.New("csrf: origin header invalid") - ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins") - errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user - dummyValue = []byte{'+'} // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value. + ErrTokenNotFound = errors.New("csrf: token not found") + ErrTokenInvalid = errors.New("csrf: token invalid") + ErrFetchSiteInvalid = errors.New("csrf: sec-fetch-site header invalid") + ErrRefererNotFound = errors.New("csrf: referer header missing") + ErrRefererInvalid = errors.New("csrf: referer header invalid") + ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins") + ErrOriginInvalid = errors.New("csrf: origin header invalid") + ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins") + errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user + dummyValue = []byte{'+'} // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value. ) @@ -127,6 +128,11 @@ func New(config ...Config) fiber.Handler { default: // Assume that anything not defined as 'safe' by RFC7231 needs protection + // Evaluate Sec-Fetch-Site to reject cross-site requests earlier when available. + if err := validateSecFetchSite(c); err != nil { + return cfg.ErrorHandler(c, err) + } + // Enforce an origin check for unsafe requests. err := originMatchesHost(c, trustedOrigins, trustedSubOrigins) @@ -313,6 +319,21 @@ func (handler *Handler) DeleteToken(c fiber.Ctx) error { return nil } +func validateSecFetchSite(c fiber.Ctx) error { + secFetchSite := utils.Trim(c.Get(fiber.HeaderSecFetchSite), ' ') + + if secFetchSite == "" { + return nil + } + + switch utils.ToLower(secFetchSite) { + case "same-origin", "none", "cross-site", "same-site": + return nil + default: + return ErrFetchSiteInvalid + } +} + // originMatchesHost checks that the origin header matches the host header // returns an error if the origin header is not present or is invalid // returns nil if the origin header is valid diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 646ffc695a3..812dc8959e3 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -823,6 +823,182 @@ func Test_CSRF_Extractor_EmptyString(t *testing.T) { require.Equal(t, ErrTokenNotFound.Error(), string(ctx.Response.Body())) } +func Test_CSRF_SecFetchSite(t *testing.T) { + t.Parallel() + + errorHandler := func(c fiber.Ctx, err error) error { + return c.Status(fiber.StatusForbidden).SendString(err.Error()) + } + + app := fiber.New() + + app.Use(New(Config{ErrorHandler: errorHandler})) + + app.All("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.URI().SetScheme("http") + ctx.Request.URI().SetHost("example.com") + ctx.Request.Header.SetHost("example.com") + h(ctx) + token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) + token = strings.Split(strings.Split(token, ";")[0], "=")[1] + + tests := []struct { + name string + method string + secFetchSite string + origin string + expectedStatus int16 + https bool + expectFetchSiteInvalid bool + }{ + { + name: "same-origin allowed", + method: fiber.MethodPost, + secFetchSite: "same-origin", + origin: "http://example.com", + expectedStatus: http.StatusOK, + }, + { + name: "none allowed", + method: fiber.MethodPost, + secFetchSite: "none", + origin: "http://example.com", + expectedStatus: http.StatusOK, + }, + { + name: "cross-site with origin allowed", + method: fiber.MethodPost, + secFetchSite: "cross-site", + origin: "http://example.com", + expectedStatus: http.StatusOK, + }, + { + name: "same-site with origin allowed", + method: fiber.MethodPost, + secFetchSite: "same-site", + origin: "http://example.com", + expectedStatus: http.StatusOK, + }, + { + name: "cross-site with mismatched origin blocked", + method: fiber.MethodPost, + secFetchSite: "cross-site", + origin: "https://attacker.example", + expectedStatus: http.StatusForbidden, + }, + { + name: "same-site with null origin blocked", + method: fiber.MethodPost, + secFetchSite: "same-site", + origin: "null", + expectedStatus: http.StatusForbidden, + https: true, + }, + { + name: "invalid header blocked", + method: fiber.MethodPost, + secFetchSite: "weird", + origin: "http://example.com", + expectedStatus: http.StatusForbidden, + expectFetchSiteInvalid: true, + }, + { + name: "no header with no origin", + method: fiber.MethodPost, + origin: "", + expectedStatus: http.StatusOK, + }, + { + name: "no header with matching origin", + method: fiber.MethodPost, + origin: "http://example.com", + expectedStatus: http.StatusOK, + }, + { + name: "no header with mismatched origin", + method: fiber.MethodPost, + origin: "https://attacker.example", + expectedStatus: http.StatusForbidden, + }, + { + name: "no header with null origin", + method: fiber.MethodPost, + origin: "null", + expectedStatus: http.StatusForbidden, + https: true, + }, + { + name: "GET allowed", + method: fiber.MethodGet, + secFetchSite: "cross-site", + expectedStatus: http.StatusOK, + }, + { + name: "HEAD allowed", + method: fiber.MethodHead, + secFetchSite: "cross-site", + expectedStatus: http.StatusOK, + }, + { + name: "OPTIONS allowed", + method: fiber.MethodOptions, + secFetchSite: "cross-site", + expectedStatus: http.StatusOK, + }, + { + name: "PUT with mismatched origin blocked", + method: fiber.MethodPut, + secFetchSite: "cross-site", + origin: "https://attacker.example", + expectedStatus: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c := &fasthttp.RequestCtx{} + scheme := "http" + if tt.https { + scheme = "https" + } + c.Request.Header.SetMethod(tt.method) + c.Request.URI().SetScheme(scheme) + c.Request.URI().SetHost("example.com") + c.Request.Header.SetHost("example.com") + c.Request.Header.SetProtocol(scheme) + if scheme == "https" { + c.Request.Header.Set(fiber.HeaderXForwardedProto, "https") + } + if tt.origin != "" { + c.Request.Header.Set(fiber.HeaderOrigin, tt.origin) + } + if tt.secFetchSite != "" { + c.Request.Header.Set(fiber.HeaderSecFetchSite, tt.secFetchSite) + } + + safe := tt.method == fiber.MethodGet || tt.method == fiber.MethodHead || tt.method == fiber.MethodOptions || tt.method == fiber.MethodTrace + + if !safe { + c.Request.Header.Set(HeaderName, token) + c.Request.Header.SetCookie(ConfigDefault.CookieName, token) + } + + h(c) + require.Equal(t, int(tt.expectedStatus), c.Response.StatusCode()) + if tt.expectFetchSiteInvalid { + require.Equal(t, ErrFetchSiteInvalid.Error(), string(c.Response.Body())) + } + }) + } +} + func Test_CSRF_Origin(t *testing.T) { t.Parallel() app := fiber.New()