Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ These targets can be invoked via `make <target>` as needed during development an
## Pull request guidelines

- PR titles must start with a category prefix describing the change: `🐛 bug:`, `🔥 feat:`, `📒 docs:`, or `🧹 chore:`.
- Generated PR bodies should contain a **Summary** section that captures all changes included in the PR, not just the latest commit.
- Generated PR titles and bodies must summarize the *entire* set of changes on the branch (for example, based on `git log --oneline <base>..HEAD` or the full diff), **not** just the latest commit. The Summary section should reflect all modifications that will be merged.

## Programmatic checks

Expand All @@ -75,3 +75,7 @@ make test
```

All checks must pass before the generated code can be merged.

After completing the programmatic checks above, confirm that any relevant
documentation has been updated to reflect the changes made, including PR
instructions when applicable.
1 change: 1 addition & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ const (
HeaderTE = "TE"
HeaderTrailer = "Trailer"
HeaderTransferEncoding = "Transfer-Encoding"
HeaderSecFetchSite = "Sec-Fetch-Site"
HeaderSecWebSocketAccept = "Sec-WebSocket-Accept"
HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions"
HeaderSecWebSocketKey = "Sec-WebSocket-Key"
Expand Down
4 changes: 4 additions & 0 deletions docs/middleware/csrf.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ async function makeRequest(url, data) {

The middleware employs a robust, defense-in-depth strategy to protect against CSRF attacks. The primary defense is token-based validation, which operates in one of two modes depending on your configuration. This is supplemented by a mandatory secondary check on the request's origin.

### Fetch Metadata Guardrails

- **Sec-Fetch-Site**: For unsafe methods, the middleware inspects the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header when present. If the header value is not one of "same-origin", "none", "same-site", or "cross-site", the request is rejected with `ErrFetchSiteInvalid`. If the header is valid or absent, the request proceeds to the standard origin and token validation checks. This provides an early check to block requests with invalid `Sec-Fetch-Site` values, while allowing legitimate same-site and cross-site requests to be validated by the existing mechanisms.

Comment thread
gaby marked this conversation as resolved.
### 1. Token Validation Patterns

#### Double Submit Cookie (Default Mode)
Expand Down
2 changes: 2 additions & 0 deletions docs/whats_new.md
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,8 @@ The `Expiration` field in the CSRF middleware configuration has been renamed to

CSRF now redacts tokens and storage keys by default and exposes a `DisableValueRedaction` toggle (default `false`) if you must surface those values in diagnostics.

The CSRF middleware now validates the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header for unsafe HTTP methods. When present, requests with invalid `Sec-Fetch-Site` values (not one of "same-origin", "none", "same-site", or "cross-site") are rejected with `ErrFetchSiteInvalid`. Valid or absent headers proceed to standard origin and token validation checks, providing an early gate to catch malformed requests while maintaining compatibility with legitimate cross-site traffic.

### Idempotency

Idempotency middleware now redacts keys by default and offers a `DisableValueRedaction` configuration flag (default `false`) to expose them when debugging.
Expand Down
39 changes: 30 additions & 9 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ import (
)

var (
ErrTokenNotFound = errors.New("csrf: token not found")
ErrTokenInvalid = errors.New("csrf: token invalid")
ErrRefererNotFound = errors.New("csrf: referer header missing")
ErrRefererInvalid = errors.New("csrf: referer header invalid")
ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins")
ErrOriginInvalid = errors.New("csrf: origin header invalid")
ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins")
errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
dummyValue = []byte{'+'} // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value.
ErrTokenNotFound = errors.New("csrf: token not found")
ErrTokenInvalid = errors.New("csrf: token invalid")
ErrFetchSiteInvalid = errors.New("csrf: sec-fetch-site header invalid")
ErrRefererNotFound = errors.New("csrf: referer header missing")
ErrRefererInvalid = errors.New("csrf: referer header invalid")
ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins")
ErrOriginInvalid = errors.New("csrf: origin header invalid")
ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins")
errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
dummyValue = []byte{'+'} // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value.

)

Expand Down Expand Up @@ -127,6 +128,11 @@ func New(config ...Config) fiber.Handler {
default:
// Assume that anything not defined as 'safe' by RFC7231 needs protection

// Evaluate Sec-Fetch-Site to reject cross-site requests earlier when available.
if err := validateSecFetchSite(c); err != nil {
return cfg.ErrorHandler(c, err)
Comment thread
gaby marked this conversation as resolved.
}

// Enforce an origin check for unsafe requests.
err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)

Expand Down Expand Up @@ -313,6 +319,21 @@ func (handler *Handler) DeleteToken(c fiber.Ctx) error {
return nil
}

func validateSecFetchSite(c fiber.Ctx) error {
secFetchSite := utils.Trim(c.Get(fiber.HeaderSecFetchSite), ' ')

if secFetchSite == "" {
return nil
}

switch utils.ToLower(secFetchSite) {
case "same-origin", "none", "cross-site", "same-site":
return nil
default:
return ErrFetchSiteInvalid
}
}
Comment thread
gaby marked this conversation as resolved.
Comment thread
gaby marked this conversation as resolved.

// originMatchesHost checks that the origin header matches the host header
// returns an error if the origin header is not present or is invalid
// returns nil if the origin header is valid
Expand Down
176 changes: 176 additions & 0 deletions middleware/csrf/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,182 @@ func Test_CSRF_Extractor_EmptyString(t *testing.T) {
require.Equal(t, ErrTokenNotFound.Error(), string(ctx.Response.Body()))
}

func Test_CSRF_SecFetchSite(t *testing.T) {
t.Parallel()

errorHandler := func(c fiber.Ctx, err error) error {
return c.Status(fiber.StatusForbidden).SendString(err.Error())
}

app := fiber.New()

app.Use(New(Config{ErrorHandler: errorHandler}))

app.All("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

h := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.URI().SetScheme("http")
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetHost("example.com")
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]

tests := []struct {
name string
method string
secFetchSite string
origin string
expectedStatus int16
https bool
expectFetchSiteInvalid bool
}{
{
name: "same-origin allowed",
method: fiber.MethodPost,
secFetchSite: "same-origin",
origin: "http://example.com",
expectedStatus: http.StatusOK,
},
{
name: "none allowed",
method: fiber.MethodPost,
secFetchSite: "none",
origin: "http://example.com",
expectedStatus: http.StatusOK,
},
{
name: "cross-site with origin allowed",
method: fiber.MethodPost,
secFetchSite: "cross-site",
origin: "http://example.com",
expectedStatus: http.StatusOK,
},
{
name: "same-site with origin allowed",
method: fiber.MethodPost,
secFetchSite: "same-site",
origin: "http://example.com",
expectedStatus: http.StatusOK,
},
{
name: "cross-site with mismatched origin blocked",
method: fiber.MethodPost,
secFetchSite: "cross-site",
origin: "https://attacker.example",
expectedStatus: http.StatusForbidden,
},
{
name: "same-site with null origin blocked",
method: fiber.MethodPost,
secFetchSite: "same-site",
origin: "null",
expectedStatus: http.StatusForbidden,
https: true,
},
{
name: "invalid header blocked",
method: fiber.MethodPost,
secFetchSite: "weird",
origin: "http://example.com",
expectedStatus: http.StatusForbidden,
expectFetchSiteInvalid: true,
},
{
name: "no header with no origin",
method: fiber.MethodPost,
origin: "",
expectedStatus: http.StatusOK,
},
{
name: "no header with matching origin",
method: fiber.MethodPost,
origin: "http://example.com",
expectedStatus: http.StatusOK,
},
{
name: "no header with mismatched origin",
method: fiber.MethodPost,
origin: "https://attacker.example",
expectedStatus: http.StatusForbidden,
},
{
name: "no header with null origin",
method: fiber.MethodPost,
origin: "null",
expectedStatus: http.StatusForbidden,
https: true,
},
{
name: "GET allowed",
method: fiber.MethodGet,
secFetchSite: "cross-site",
expectedStatus: http.StatusOK,
},
{
name: "HEAD allowed",
method: fiber.MethodHead,
secFetchSite: "cross-site",
expectedStatus: http.StatusOK,
},
{
name: "OPTIONS allowed",
method: fiber.MethodOptions,
secFetchSite: "cross-site",
expectedStatus: http.StatusOK,
},
{
name: "PUT with mismatched origin blocked",
method: fiber.MethodPut,
secFetchSite: "cross-site",
origin: "https://attacker.example",
expectedStatus: http.StatusForbidden,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
c := &fasthttp.RequestCtx{}
scheme := "http"
if tt.https {
scheme = "https"
}
c.Request.Header.SetMethod(tt.method)
c.Request.URI().SetScheme(scheme)
c.Request.URI().SetHost("example.com")
c.Request.Header.SetHost("example.com")
c.Request.Header.SetProtocol(scheme)
if scheme == "https" {
c.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
}
if tt.origin != "" {
c.Request.Header.Set(fiber.HeaderOrigin, tt.origin)
}
if tt.secFetchSite != "" {
c.Request.Header.Set(fiber.HeaderSecFetchSite, tt.secFetchSite)
}

safe := tt.method == fiber.MethodGet || tt.method == fiber.MethodHead || tt.method == fiber.MethodOptions || tt.method == fiber.MethodTrace

if !safe {
c.Request.Header.Set(HeaderName, token)
c.Request.Header.SetCookie(ConfigDefault.CookieName, token)
}

h(c)
require.Equal(t, int(tt.expectedStatus), c.Response.StatusCode())
if tt.expectFetchSiteInvalid {
require.Equal(t, ErrFetchSiteInvalid.Error(), string(c.Response.Body()))
}
})
}
}

func Test_CSRF_Origin(t *testing.T) {
t.Parallel()
app := fiber.New()
Expand Down
Loading