diff --git a/addon/retry/exponential_backoff.go b/addon/retry/exponential_backoff.go index 6740c35971f..199831bea0e 100644 --- a/addon/retry/exponential_backoff.go +++ b/addon/retry/exponential_backoff.go @@ -45,7 +45,7 @@ func (e *ExponentialBackoff) Retry(f func() error) error { e.currentInterval = e.InitialInterval } var err error - for i := 0; i < e.MaxRetryCount; i++ { + for range e.MaxRetryCount { err = f() if err == nil { return nil diff --git a/addon/retry/exponential_backoff_test.go b/addon/retry/exponential_backoff_test.go index 844ed0df94c..63dd193526e 100644 --- a/addon/retry/exponential_backoff_test.go +++ b/addon/retry/exponential_backoff_test.go @@ -107,7 +107,7 @@ func Test_ExponentialBackoff_Next(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - for i := 0; i < tt.expBackoff.MaxRetryCount; i++ { + for i := range tt.expBackoff.MaxRetryCount { next := tt.expBackoff.next() if next < tt.expNextTimeIntervals[i] || next > tt.expNextTimeIntervals[i]+1*time.Second { t.Errorf("wrong next time:\n"+ diff --git a/bind_test.go b/bind_test.go index 2246505dd2d..f89f0b4b6b9 100644 --- a/bind_test.go +++ b/bind_test.go @@ -149,7 +149,7 @@ func Test_Bind_Query_Map(t *testing.T) { em := make(map[string][]int) c.Request().URI().SetQueryString("") - require.ErrorIs(t, c.Bind().Query(&em), binder.ErrMapNotConvertable) + require.ErrorIs(t, c.Bind().Query(&em), binder.ErrMapNotConvertible) } // go test -run Test_Bind_Query_WithSetParserDecoder -v diff --git a/binder/binder.go b/binder/binder.go index 06c7c926a50..14c4bde3605 100644 --- a/binder/binder.go +++ b/binder/binder.go @@ -8,7 +8,7 @@ import ( // Binder errors var ( ErrSuitableContentNotFound = errors.New("binder: suitable content not found to parse body") - ErrMapNotConvertable = errors.New("binder: map is not convertable to map[string]string or map[string][]string") + ErrMapNotConvertible = errors.New("binder: map is not convertible to map[string]string or map[string][]string") ) var HeaderBinderPool = sync.Pool{ diff --git a/binder/mapping.go b/binder/mapping.go index 7cebff5c7b5..6bd8809c730 100644 --- a/binder/mapping.go +++ b/binder/mapping.go @@ -113,14 +113,14 @@ func parseToMap(ptr any, data map[string][]string) error { case reflect.Slice: newMap, ok := ptr.(map[string][]string) if !ok { - return ErrMapNotConvertable + return ErrMapNotConvertible } maps.Copy(newMap, data) case reflect.String, reflect.Interface: newMap, ok := ptr.(map[string]string) if !ok { - return ErrMapNotConvertable + return ErrMapNotConvertible } for k, v := range data { diff --git a/binder/mapping_test.go b/binder/mapping_test.go index f337ec435e5..11cf8c54f65 100644 --- a/binder/mapping_test.go +++ b/binder/mapping_test.go @@ -152,7 +152,7 @@ func Test_parseToMap(t *testing.T) { // Test map[string]any m3 := make(map[string]any) err = parseToMap(m3, inputMap) - require.ErrorIs(t, err, ErrMapNotConvertable) + require.ErrorIs(t, err, ErrMapNotConvertible) } func Test_FilterFlags(t *testing.T) { diff --git a/ctx.go b/ctx.go index 31361ff1122..26060be4d64 100644 --- a/ctx.go +++ b/ctx.go @@ -1751,21 +1751,35 @@ func (c *DefaultCtx) setCanonical(key, val string) { c.fasthttp.Response.Header.SetCanonical(utils.UnsafeBytes(key), utils.UnsafeBytes(val)) } -// Subdomains returns a string slice of subdomains in the domain name of the request. -// The subdomain offset, which defaults to 2, is used for determining the beginning of the subdomain segments. +// Subdomains returns a slice of subdomains from the host, excluding the last `offset` components. +// If the offset is negative or exceeds the number of subdomains, an empty slice is returned. +// If the offset is zero every label (no trimming) is returned. func (c *DefaultCtx) Subdomains(offset ...int) []string { o := 2 if len(offset) > 0 { o = offset[0] } - subdomains := strings.Split(c.Host(), ".") - l := len(subdomains) - o - // Check index to avoid slice bounds out of range panic - if l < 0 { - l = len(subdomains) + + // Negative offset, return nothing. + if o < 0 { + return []string{} + } + + // strip “:port” if present + host := c.Hostname() + parts := strings.Split(host, ".") + + // offset == 0, caller wants everything. + if o == 0 { + return parts } - subdomains = subdomains[:l] - return subdomains + + // If we trim away the whole slice (or more), nothing remains. + if o >= len(parts) { + return []string{} + } + + return parts[:len(parts)-o] } // Stale is not implemented yet, pull requests are welcome! diff --git a/ctx_test.go b/ctx_test.go index 87eddd46413..ce083d6e111 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -3015,15 +3015,88 @@ func Test_Ctx_Stale(t *testing.T) { // go test -run Test_Ctx_Subdomains func Test_Ctx_Subdomains(t *testing.T) { - t.Parallel() app := New() - c := app.AcquireCtx(&fasthttp.RequestCtx{}) - c.Request().URI().SetHost("john.doe.is.awesome.google.com") - require.Equal(t, []string{"john", "doe"}, c.Subdomains(4)) + type tc struct { + name string + host string + offset []int // nil ⇒ call without argument + want []string + } - c.Request().URI().SetHost("localhost:3000") - require.Equal(t, []string{"localhost:3000"}, c.Subdomains()) + cases := []tc{ + { + name: "default offset (2) drops registrable domain + TLD", + host: "john.doe.is.awesome.google.com", + offset: nil, // Subdomains() + want: []string{"john", "doe", "is", "awesome"}, + }, + { + name: "custom offset trims N right-hand labels", + host: "john.doe.is.awesome.google.com", + offset: []int{4}, + want: []string{"john", "doe"}, + }, + { + name: "offset too high returns empty", + host: "john.doe.is.awesome.google.com", + offset: []int{10}, + want: []string{}, + }, + { + name: "zero offset returns all labels", + host: "john.doe.google.com", + offset: []int{0}, + want: []string{"john", "doe", "google", "com"}, + }, + { + name: "offset 1 keeps registrable domain", + host: "john.doe.google.com", + offset: []int{1}, + want: []string{"john", "doe", "google"}, + }, + { + name: "negative offset returns empty", + host: "john.doe.google.com", + offset: []int{-1}, + want: []string{}, + }, + { + name: "offset equal len returns empty", + host: "john.doe.com", + offset: []int{3}, + want: []string{}, + }, + { + name: "offset equal len returns empty", + host: "john.doe.com", + offset: []int{3}, + want: []string{}, + }, + { + name: "zero offset returns all labels with port present", + host: "localhost:3000", + offset: []int{0}, + want: []string{"localhost"}, + }, + { + name: "host with port — custom offset trims 2 labels", + host: "foo.bar.example.com:8080", + offset: []int{2}, + want: []string{"foo", "bar"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(c) + + c.Request().URI().SetHost(tc.host) + got := c.Subdomains(tc.offset...) + require.Equal(t, tc.want, got) + }) + } } // go test -v -run=^$ -bench=Benchmark_Ctx_Subdomains -benchmem -count=4 diff --git a/docs/api/bind.md b/docs/api/bind.md index f81a370d939..4b4d6660628 100644 --- a/docs/api/bind.md +++ b/docs/api/bind.md @@ -538,7 +538,7 @@ For more control over error handling, you can use the following methods. If you want to handle binder errors automatically, you can use `WithAutoHandling`. If there's an error, it will return the error and set HTTP status to `400 Bad Request`. -This function does NOT panic therefor you must still return on error explicitly +This function does NOT panic therefore you must still return on error explicitly ```go title="Signature" func (b *Bind) WithAutoHandling() *Bind diff --git a/docs/api/ctx.md b/docs/api/ctx.md index 438ba591e1a..0a4de67d7c6 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -1344,21 +1344,33 @@ func (c fiber.Ctx) Stale() bool ### Subdomains -Returns a slice of subdomains in the domain name of the request. +Returns a slice with the host’s sub-domain labels. The dot-separated parts that precede the registrable domain (`example`) and the top-level domain (ex: `com`). -The application property `subdomain offset`, which defaults to `2`, is used for determining the beginning of the subdomain segments. +The `subdomain offset` (default `2`) tells Fiber how many labels, counting from the right-hand side, are always discarded. +Passing an `offset` argument lets you override that value for a single call. -```go title="Signature" +```go func (c fiber.Ctx) Subdomains(offset ...int) []string ``` -```go title="Example" +| `offset` | Result | Meaning | +|----------|----------------------------------------|-------------------------------------------------------| +| *omitted* → **2** | trim 2 right-most labels | drop the registrable domain **and** the TLD | +| `1` to `len(labels)-1` | trim exactly `offset` right-most labels | custom trimming of available labels | +| `>= len(labels)` | **return `[]`** | offset exceeds available labels → empty slice | +| `0` | **return every label** | keep the entire host unchanged | +| `< 0` | **return `[]`** | negative offsets are invalid → empty slice | + +#### Example + +```go // Host: "tobi.ferrets.example.com" app.Get("/", func(c fiber.Ctx) error { - c.Subdomains() // ["ferrets", "tobi"] - c.Subdomains(1) // ["tobi"] - + c.Subdomains() // ["tobi", "ferrets"] + c.Subdomains(1) // ["tobi", "ferrets", "example"] + c.Subdomains(0) // ["tobi", "ferrets", "example", "com"] + c.Subdomains(-1) // [] // ... }) ``` @@ -1586,8 +1598,8 @@ app.Get("/", func(c fiber.Ctx) error { Transfers the file from the given path as an `attachment`. -Typically, browsers will prompt the user to download. By default, the [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header `filename=` parameter is the file path (_this typically appears in the browser dialog_). -Override this default with the **filename** parameter. +Typically, browsers will prompt the user to download. By default, the [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header `filename=` parameter is the file path (this typically appears in the browser dialog). +Override this default with the `filename` parameter. ```go title="Signature" func (c fiber.Ctx) Download(file string, filename ...string) error diff --git a/docs/guide/utils.md b/docs/guide/utils.md index 1f3fec1c696..a5231383903 100644 --- a/docs/guide/utils.md +++ b/docs/guide/utils.md @@ -13,7 +13,7 @@ Converts a string value to a specified type, handling errors and optional defaul This function simplifies the conversion process by encapsulating error handling and the management of default values, making your code cleaner and more consistent. ```go title="Signature" -func Convert[T any](value string, convertor func(string) (T, error), defaultValue ...T) (*T, error) +func Convert[T any](value string, converter func(string) (T, error), defaultValue ...T) (*T, error) ``` ```go title="Example" diff --git a/docs/middleware/cors.md b/docs/middleware/cors.md index 1c5124a9923..a07be1b290e 100644 --- a/docs/middleware/cors.md +++ b/docs/middleware/cors.md @@ -169,6 +169,12 @@ The CORS middleware works by adding the necessary CORS headers to responses from When a request comes in, the middleware first checks if it's a preflight request, which is a CORS mechanism to determine whether the actual request is safe to send. Preflight requests are HTTP OPTIONS requests with specific CORS headers. If it's a preflight request, the middleware responds with the appropriate CORS headers and ends the request. +:::note +Preflight requests are typically sent by browsers before making actual cross-origin requests, especially for methods other than GET or POST, or when custom headers are used. + +A preflight request is an HTTP OPTIONS request that includes the `Origin`, `Access-Control-Request-Method`, and optionally `Access-Control-Request-Headers` headers. The browser sends this request to check if the server allows the actual request method and headers. +::: + If it's not a preflight request, the middleware adds the CORS headers to the response and passes the request to the next handler. The actual CORS headers added depend on the configuration of the middleware. The `AllowOrigins` option controls which origins can make cross-origin requests. The middleware handles different `AllowOrigins` configurations as follows: @@ -197,6 +203,171 @@ The `MaxAge` option indicates how long the results of a preflight request can be The `Vary` header is used in this middleware to inform the client that the server's response to a request. For or both preflight and actual requests, the Vary header is set to `Access-Control-Request-Method` and `Access-Control-Request-Headers`. For preflight requests, the Vary header is also set to `Origin`. The `Vary` header is important for caching. It helps caches (like a web browser's cache or a CDN) determine when a cached response can be used in response to a future request, and when the server needs to be queried for a new response. +## Infrastructure Considerations + +When deploying Fiber applications behind infrastructure components like CDNs, API gateways, load balancers, or reverse proxies, you have two main options for handling CORS: + +### Option 1: Use Infrastructure-Level CORS (Recommended) + +**For most production deployments, it's often preferable to handle CORS at the infrastructure level** rather than in your Fiber application. This approach offers several advantages: + +- **Better Performance**: CORS headers are added at the edge, closer to the client +- **Reduced Server Load**: Preflight requests are handled without reaching your application +- **Centralized Configuration**: Manage CORS policies alongside other infrastructure settings +- **Built-in Caching**: Infrastructure providers optimize CORS response caching + +**Common infrastructure CORS solutions:** + +- **CDNs**: CloudFront, CloudFlare, Azure CDN - handle CORS at edge locations +- **API Gateways**: AWS API Gateway, Google Cloud API Gateway - centralized CORS management +- **Load Balancers**: Application Load Balancers with CORS rules +- **Reverse Proxies**: Nginx, Apache with CORS modules + +If using infrastructure-level CORS, **disable Fiber's CORS middleware** to avoid conflicts: + +```go +// Don't use both - choose one approach +// app.Use(cors.New()) // Remove this line when using infrastructure CORS +``` + +### Option 2: Application-Level CORS (Fiber Middleware) + +Use Fiber's CORS middleware when you need: + +- **Dynamic origin validation** based on application logic +- **Fine-grained control** over CORS policies per route +- **Integration with application state** (database-driven origins, etc.) +- **Development environments** where infrastructure CORS isn't available + +If choosing this approach, ensure that **all CORS headers reach your Fiber application unchanged**. + +### Required Headers for CORS Preflight Requests + +For CORS preflight requests to work correctly, these headers **must not be stripped or modified by caching layers**: + +- `Origin` - Required to identify the requesting origin +- `Access-Control-Request-Method` - Required to identify the HTTP method for the actual request +- `Access-Control-Request-Headers` - Optional, contains custom headers the actual request will use +- `Access-Control-Request-Private-Network` - Optional, for private network access requests + +:::warning Critical Preflight Requirement +If the `Access-Control-Request-Method` header is missing from an OPTIONS request, Fiber will not recognize them as CORS preflight requests. Instead, they'll be treated as regular OPTIONS requests, which typically return `405 Method Not Allowed` since most applications don't define explicit OPTIONS handlers. +::: + +### CORS Response Headers (Set by Fiber) + +The middleware sets these response headers based on your configuration: + +**For all CORS requests:** + +- `Access-Control-Allow-Origin` - Set to the allowed origin or "*" +- `Access-Control-Allow-Credentials` - Set to "true" when `AllowCredentials: true` +- `Access-Control-Expose-Headers` - Lists headers the client can access +- `Vary` - Set to "Origin" (unless wildcard origins are used) + +**For preflight responses only:** + +- `Access-Control-Allow-Methods` - Lists allowed HTTP methods +- `Access-Control-Allow-Headers` - Lists allowed request headers (or echoes the request) +- `Access-Control-Max-Age` - Cache duration for preflight results (if MaxAge > 0) +- `Access-Control-Allow-Private-Network` - Set to "true" when private network access is allowed +- `Vary` - Set to "Access-Control-Request-Method, Access-Control-Request-Headers, Origin" + +### Common Infrastructure Issues + +**CDNs (CloudFront, CloudFlare, etc.)**: + +- Configure cache policies to forward all CORS headers +- Ensure OPTIONS requests are not cached inappropriately or cache them correctly with proper Vary headers +- Don't strip or modify CORS request headers + +**API Gateways**: + +- Choose either gateway-level CORS OR application-level CORS, not both +- If using gateway CORS, disable Fiber's CORS middleware +- If forwarding to Fiber, ensure all headers pass through unchanged + +**Load Balancers/Reverse Proxies**: + +- Preserve all HTTP headers, especially CORS-related ones +- Don't modify or strip `Origin`, `Access-Control-Request-*` headers + +**WAFs/Security Services**: + +- Whitelist CORS headers in security rules +- Ensure OPTIONS requests with CORS headers aren't blocked + +### Debugging CORS Issues + +Add this middleware **before** your CORS configuration to debug what headers Fiber receives: + +```go +// Debug middleware to log CORS preflight requests +// Only use in development or testing environments +app.Use(func(c *fiber.Ctx) error { + if c.Method() == "OPTIONS" { + fmt.Printf("OPTIONS %s\n", c.Path()) + fmt.Printf(" Origin: %s\n", c.Get("Origin")) + fmt.Printf(" Access-Control-Request-Method: %s\n", c.Get("Access-Control-Request-Method")) + fmt.Printf(" Access-Control-Request-Headers: %s\n", c.Get("Access-Control-Request-Headers")) + } + return c.Next() +}) + +app.Use(cors.New(cors.Config{ + AllowOrigins: []string{"https://yourdomain.com"}, + AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, +})) +``` + +Test CORS preflight directly with curl: + +```bash +# Test preflight request +curl -X OPTIONS https://your-app.com/api/test \ + -H "Origin: https://yourdomain.com" \ + -H "Access-Control-Request-Method: POST" \ + -H "Access-Control-Request-Headers: Content-Type" \ + -v + +# Test simple CORS request +curl -X GET https://your-app.com/api/test \ + -H "Origin: https://yourdomain.com" \ + -v +``` + +### Caching Considerations + +The middleware sets appropriate `Vary` headers to ensure proper caching: + +- **Non-wildcard origins**: `Vary: Origin` is set to cache responses per origin +- **Preflight requests**: `Vary: Access-Control-Request-Method, Access-Control-Request-Headers, Origin` +- **OPTIONS without preflight headers**: `Vary: Origin` to avoid cache poisoning + +Ensure your infrastructure respects these `Vary` headers for correct caching behavior. + +### Choosing the Right Approach + +| Scenario | Recommended Approach | +|----------|---------------------| +| Production with CDN/API Gateway | Infrastructure-level CORS | +| Dynamic origin validation needed | Application-level CORS | +| Microservices with different CORS policies | Application-level CORS | +| Simple static origins | Infrastructure-level CORS | +| Development/testing | Application-level CORS | +| High traffic applications | Infrastructure-level CORS | + +:::tip Infrastructure CORS Configuration +Most cloud providers offer comprehensive CORS documentation: + +- [AWS CloudFront CORS](https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/header-caching.html#header-caching-web-cors) +- [Google Cloud CORS](https://cloud.google.com/storage/docs/cross-origin) +- [Azure CDN CORS](https://docs.microsoft.com/en-us/azure/cdn/cdn-cors) +- [CloudFlare CORS](https://developers.cloudflare.com/fundamentals/get-started/reference/http-request-headers/#cf-connecting-ip) + +Configure CORS at the infrastructure level when possible for optimal performance and reduced complexity. +::: + ## Security Considerations When configuring CORS, misconfiguration can potentially expose your application to various security risks. Here are some secure configurations and common pitfalls to avoid: diff --git a/docs/middleware/encryptcookie.md b/docs/middleware/encryptcookie.md index 004b4bee175..f2625ad0220 100644 --- a/docs/middleware/encryptcookie.md +++ b/docs/middleware/encryptcookie.md @@ -13,7 +13,7 @@ This middleware encrypts cookie values and not the cookie names. ## Signatures ```go -// Intitializes the middleware +// Initializes the middleware func New(config ...Config) fiber.Handler // GenerateKey returns a random string of 16, 24, or 32 bytes. diff --git a/docs/middleware/envvar.md b/docs/middleware/envvar.md index 4467b7349d6..897357ac55b 100644 --- a/docs/middleware/envvar.md +++ b/docs/middleware/envvar.md @@ -26,14 +26,13 @@ import ( After you initiate your Fiber app, you can use the following possibilities: ```go -// Initialize default config +// Initialize default config (exports no variables) app.Use("/expose/envvars", envvar.New()) // Or extend your config for customization app.Use("/expose/envvars", envvar.New( envvar.Config{ - ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"}, - ExcludeVars: map[string]string{"excludeKey": ""}, + ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"}, }), ) ``` @@ -60,11 +59,11 @@ Http response contract: | Property | Type | Description | Default | |:------------|:--------------------|:-----------------------------------------------------------------------------|:--------| -| ExportVars | `map[string]string` | ExportVars specifies the environment variables that should be exported. | `nil` | -| ExcludeVars | `map[string]string` | ExcludeVars specifies the environment variables that should not be exported. | `nil` | +| ExportVars | `map[string]string` | ExportVars specifies the environment variables that should be exported. | `nil` | ## Default Config ```go Config{} +// Exports no environment variables ``` diff --git a/docs/middleware/session.md b/docs/middleware/session.md index 65d1c96bd78..43ed0b38621 100644 --- a/docs/middleware/session.md +++ b/docs/middleware/session.md @@ -194,9 +194,10 @@ func DefaultErrorHandler(fiber.Ctx, err error) ### Middleware Methods ```go -func (m *Middleware) Set(key string, value any) -func (m *Middleware) Get(key string) any -func (m *Middleware) Delete(key string) +func (m *Middleware) Set(key any, value any) +func (m *Middleware) Get(key any) any +func (m *Middleware) Delete(key any) +func (m *Middleware) Keys() []any func (m *Middleware) Destroy() error func (m *Middleware) Reset() error func (m *Middleware) Store() *Store @@ -207,14 +208,15 @@ func (m *Middleware) Store() *Store ```go func (s *Session) Fresh() bool func (s *Session) ID() string -func (s *Session) Get(key string) any -func (s *Session) Set(key string, val any) +func (s *Session) Get(key any) any +func (s *Session) Set(key any, val any) +func (s *Session) Delete(key any) +func (s *Session) Keys() []any func (s *Session) Destroy() error func (s *Session) Regenerate() error func (s *Session) Release() func (s *Session) Reset() error func (s *Session) Save() error -func (s *Session) Keys() []string func (s *Session) SetIdleTimeout(idleTimeout time.Duration) ``` diff --git a/docs/whats_new.md b/docs/whats_new.md index ef2f88cb148..552a1de1fef 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -1,6 +1,6 @@ --- id: whats_new -title: 🆕 Whats New in v3 +title: 🆕 What's New in v3 sidebar_position: 2 toc_max_heading_level: 4 --- @@ -1012,6 +1012,10 @@ We've added support for `zstd` compression on top of `gzip`, `deflate`, and `bro Added support for specifying Key length when using `encryptcookie.GenerateKey(length)`. This allows the user to generate keys compatible with `AES-128`, `AES-192`, and `AES-256` (Default). +### EnvVar + +The `ExcludeVars` field has been removed from the EnvVar middleware configuration. When upgrading, remove any references to this field and explicitly list the variables you wish to expose using `ExportVars`. + ### Filesystem We've decided to remove filesystem middleware to clear up the confusion between static and filesystem middleware. diff --git a/go.mod b/go.mod index 1eb2f88eee2..df96fee8e84 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/gofiber/fiber/v3 go 1.24.0 require ( - github.com/gofiber/schema v1.4.0 + github.com/gofiber/schema v1.5.0 github.com/gofiber/utils/v2 v2.0.0-beta.8 github.com/google/uuid v1.6.0 github.com/mattn/go-colorable v0.1.14 @@ -12,7 +12,7 @@ require ( github.com/tinylib/msgp v1.3.0 github.com/valyala/bytebufferpool v1.0.0 github.com/valyala/fasthttp v1.62.0 - golang.org/x/crypto v0.38.0 + golang.org/x/crypto v0.39.0 ) require ( @@ -25,6 +25,6 @@ require ( github.com/x448/float16 v0.8.4 // indirect golang.org/x/net v0.40.0 // indirect golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.25.0 // indirect + golang.org/x/text v0.26.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bed6c09b40f..902d56be73d 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU= github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= -github.com/gofiber/schema v1.4.0 h1:WBCK0DvsyPQ3h+Cj3mOaN5vZdfnog0GSvOSCgci1K+s= -github.com/gofiber/schema v1.4.0/go.mod h1:YYwj01w3hVfaNjhtJzaqetymL56VW642YS3qZPhuE6c= +github.com/gofiber/schema v1.5.0 h1:dcbLol88CXdLFUY3K3TKp3SZ90v8CKIjgJp1/GfzwqU= +github.com/gofiber/schema v1.5.0/go.mod h1:YYwj01w3hVfaNjhtJzaqetymL56VW642YS3qZPhuE6c= github.com/gofiber/utils/v2 v2.0.0-beta.8 h1:ZifwbHZqZO3YJsx1ZhDsWnPjaQ7C0YD20LHt+DQeXOU= github.com/gofiber/utils/v2 v2.0.0-beta.8/go.mod h1:1lCBo9vEF4RFEtTgWntipnaScJZQiM8rrsYycLZ4n9c= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -32,15 +32,15 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/helpers.go b/helpers.go index 1084a8cb412..e2dcaacacf2 100644 --- a/helpers.go +++ b/helpers.go @@ -742,8 +742,8 @@ func IsMethodIdempotent(m string) bool { } // Convert a string value to a specified type, handling errors and optional default values. -func Convert[T any](value string, convertor func(string) (T, error), defaultValue ...T) (T, error) { - converted, err := convertor(value) +func Convert[T any](value string, converter func(string) (T, error), defaultValue ...T) (T, error) { + converted, err := converter(value) if err != nil { if len(defaultValue) > 0 { return defaultValue[0], nil diff --git a/listen.go b/listen.go index cc1e37201f8..d1f46f73217 100644 --- a/listen.go +++ b/listen.go @@ -72,19 +72,19 @@ type ListenConfig struct { // Default: NetworkTCP4 ListenerNetwork string `json:"listener_network"` - // CertFile is a path of certficate file. + // CertFile is a path of certificate file. // If you want to use TLS, you have to enter this field. // // Default : "" CertFile string `json:"cert_file"` - // KeyFile is a path of certficate's private key. + // KeyFile is a path of certificate's private key. // If you want to use TLS, you have to enter this field. // // Default : "" CertKeyFile string `json:"cert_key_file"` - // CertClientFile is a path of client certficate. + // CertClientFile is a path of client certificate. // If you want to use mTLS, you have to enter this field. // // Default : "" diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 3fc1f3d0ba7..fc97562f454 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -1,7 +1,6 @@ package cors import ( - "slices" "strconv" "strings" @@ -83,8 +82,9 @@ func New(config ...Config) fiber.Handler { return c.Next() } - // Get originHeader header - originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin)) + // Get origin header preserving the original case for the response + originHeaderRaw := c.Get(fiber.HeaderOrigin) + originHeader := strings.ToLower(originHeaderRaw) // If the request does not have Origin header, the request is outside the scope of CORS if originHeader == "" { @@ -115,15 +115,18 @@ func New(config ...Config) fiber.Handler { allowOrigin = "*" } else { // Check if the origin is in the list of allowed origins - if slices.Contains(allowOrigins, originHeader) { - allowOrigin = originHeader + for _, origin := range allowOrigins { + if origin == originHeader { + allowOrigin = originHeaderRaw + break + } } // Check if the origin is in the list of allowed subdomains if allowOrigin == "" { for _, sOrigin := range allowSOrigins { if sOrigin.match(originHeader) { - allowOrigin = originHeader + allowOrigin = originHeaderRaw break } } @@ -133,18 +136,18 @@ func New(config ...Config) fiber.Handler { // Run AllowOriginsFunc if the logic for // handling the value in 'AllowOrigins' does // not result in allowOrigin being set. - if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) { - allowOrigin = originHeader + if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeaderRaw) { + allowOrigin = originHeaderRaw } // Simple request - // Ommit allowMethods and allowHeaders, only used for pre-flight requests + // Omit allowMethods and allowHeaders, only used for pre-flight requests if c.Method() != fiber.MethodOptions { if !allowAllOrigins { // See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches c.Vary(fiber.HeaderOrigin) } - setSimpleHeaders(c, allowOrigin, maxAge, cfg) + setSimpleHeaders(c, allowOrigin, cfg) return c.Next() } @@ -162,7 +165,7 @@ func New(config ...Config) fiber.Handler { } c.Vary(fiber.HeaderOrigin) - setSimpleHeaders(c, allowOrigin, maxAge, cfg) + setPreflightHeaders(c, allowOrigin, maxAge, cfg) // Set Preflight headers if len(cfg.AllowMethods) > 0 { @@ -183,7 +186,7 @@ func New(config ...Config) fiber.Handler { } // Function to set Simple CORS headers -func setSimpleHeaders(c fiber.Ctx, allowOrigin, maxAge string, cfg Config) { +func setSimpleHeaders(c fiber.Ctx, allowOrigin string, cfg Config) { if cfg.AllowCredentials { // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' if allowOrigin == "*" { @@ -198,15 +201,20 @@ func setSimpleHeaders(c fiber.Ctx, allowOrigin, maxAge string, cfg Config) { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) } + // Set Expose-Headers if not empty + if len(cfg.ExposeHeaders) > 0 { + c.Set(fiber.HeaderAccessControlExposeHeaders, strings.Join(cfg.ExposeHeaders, ", ")) + } +} + +// Function to set Preflight CORS headers +func setPreflightHeaders(c fiber.Ctx, allowOrigin, maxAge string, cfg Config) { + setSimpleHeaders(c, allowOrigin, cfg) + // Set MaxAge if set if cfg.MaxAge > 0 { c.Set(fiber.HeaderAccessControlMaxAge, maxAge) } else if cfg.MaxAge < 0 { c.Set(fiber.HeaderAccessControlMaxAge, "0") } - - // Set Expose-Headers if not empty - if len(cfg.ExposeHeaders) > 0 { - c.Set(fiber.HeaderAccessControlExposeHeaders, strings.Join(cfg.ExposeHeaders, ", ")) - } } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index e4e963481e7..c182f22c720 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -66,6 +66,37 @@ func Test_CORS_Negative_MaxAge(t *testing.T) { require.Equal(t, "0", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) } +func Test_CORS_MaxAge_NotSetOnSimpleRequest(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Use(New(Config{MaxAge: 100})) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + app.Handler()(ctx) + + require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) +} + +func Test_CORS_Preserve_Origin_Case(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Use(New(Config{AllowOrigins: []string{"http://example.com"}})) + + origin := "HTTP://EXAMPLE.COM" + + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderOrigin, origin) + app.Handler()(ctx) + + require.Equal(t, origin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) +} + func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { t.Helper() @@ -857,7 +888,7 @@ func Test_CORS_AllowCredentials(t *testing.T) { }, RequestOrigin: "*", ResponseOrigin: "*", - // Middleware will validate that wildcard wont set credentials to true + // Middleware will validate that wildcard won't set credentials to true ResponseCredentials: "", }, { diff --git a/middleware/earlydata/earlydata_test.go b/middleware/earlydata/earlydata_test.go index 55e800ff2b1..3a5a03f67c5 100644 --- a/middleware/earlydata/earlydata_test.go +++ b/middleware/earlydata/earlydata_test.go @@ -1,14 +1,15 @@ -package earlydata_test +package earlydata import ( "errors" "fmt" "net/http/httptest" + "reflect" "testing" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/middleware/earlydata" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" ) const ( @@ -28,12 +29,12 @@ func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App { app = fiber.New(*c) } - app.Use(earlydata.New()) + app.Use(New()) // Middleware to test IsEarly func const localsKeyTestValid = "earlydata_testvalid" app.Use(func(c fiber.Ctx) error { - isEarly := earlydata.IsEarly(c) + isEarly := IsEarly(c) switch h := c.Get(headerName); h { case "", headerValOff: @@ -49,7 +50,7 @@ func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App { } default: if isEarly { - return errors.New("early-data unsuported on unsafe HTTP methods") + return errors.New("early-data unsupported on unsafe HTTP methods") } } @@ -191,3 +192,71 @@ func Test_EarlyData(t *testing.T) { trustedRun(t, app) }) } + +// Test_EarlyDataNext verifies that the middleware skips its logic when Next returns true. +func Test_EarlyDataNext(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Next: func(fiber.Ctx) bool { return true }, + })) + + called := false + app.Get("/", func(c fiber.Ctx) error { + called = true + if IsEarly(c) { + return errors.New("IsEarly(c) should be false when Next returns true") + } + return nil + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Set(headerName, headerValOn) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.True(t, called) +} + +// Test_configDefault_NoConfig verifies that calling configDefault without +// providing a configuration returns ConfigDefault as-is. +func Test_configDefault_NoConfig(t *testing.T) { + t.Parallel() + cfg := configDefault() + require.Equal(t, ConfigDefault.Error, cfg.Error) + require.Equal(t, reflect.ValueOf(ConfigDefault.IsEarlyData).Pointer(), reflect.ValueOf(cfg.IsEarlyData).Pointer()) + require.Equal(t, reflect.ValueOf(ConfigDefault.AllowEarlyData).Pointer(), reflect.ValueOf(cfg.AllowEarlyData).Pointer()) +} + +// Test_configDefault_WithConfig verifies that provided configuration fields are +// kept while missing fields are populated with defaults. +func Test_configDefault_WithConfig(t *testing.T) { + t.Parallel() + expectedErr := errors.New("boom") + called := false + custom := Config{ + Next: func(_ fiber.Ctx) bool { called = true; return false }, + Error: expectedErr, + } + + cfg := configDefault(custom) + + // Next should be preserved and not invoked by configDefault. + require.False(t, called) + require.Equal(t, reflect.ValueOf(custom.Next).Pointer(), reflect.ValueOf(cfg.Next).Pointer()) + // Custom error must be preserved. + require.Equal(t, expectedErr, cfg.Error) + // Missing fields should be set to defaults. + require.NotNil(t, cfg.IsEarlyData) + require.NotNil(t, cfg.AllowEarlyData) + + // Verify default functions behave as expected. + app := fiber.New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + c.Request().Header.Set(DefaultHeaderName, DefaultHeaderTrueValue) + c.Request().Header.SetMethod(fiber.MethodGet) + require.True(t, cfg.IsEarlyData(c)) + require.True(t, cfg.AllowEarlyData(c)) + app.ReleaseCtx(c) +} diff --git a/middleware/envvar/envvar.go b/middleware/envvar/envvar.go index 4b9aaf802a2..f6c5a4e2073 100644 --- a/middleware/envvar/envvar.go +++ b/middleware/envvar/envvar.go @@ -2,7 +2,6 @@ package envvar import ( "os" - "strings" "github.com/gofiber/fiber/v3" ) @@ -11,8 +10,6 @@ import ( type Config struct { // ExportVars specifies the environment variables that should export ExportVars map[string]string - // ExcludeVars specifies the environment variables that should not export - ExcludeVars map[string]string } type EnvVar struct { @@ -47,20 +44,16 @@ func New(config ...Config) fiber.Handler { func newEnvVar(cfg Config) *EnvVar { vars := &EnvVar{Vars: make(map[string]string)} - if len(cfg.ExportVars) > 0 { - for key, defaultVal := range cfg.ExportVars { - vars.set(key, defaultVal) - if envVal, exists := os.LookupEnv(key); exists { - vars.set(key, envVal) - } - } - } else { - const numElems = 2 - for _, envVal := range os.Environ() { - keyVal := strings.SplitN(envVal, "=", numElems) - if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists { - vars.set(keyVal[0], keyVal[1]) - } + if len(cfg.ExportVars) == 0 { + // do not expose environment variables when no configuration + // is supplied to prevent accidental information disclosure + return vars + } + + for key, defaultVal := range cfg.ExportVars { + vars.set(key, defaultVal) + if envVal, exists := os.LookupEnv(key); exists { + vars.set(key, envVal) } } diff --git a/middleware/envvar/envvar_test.go b/middleware/envvar/envvar_test.go index e34969b159d..eef65b1f146 100644 --- a/middleware/envvar/envvar_test.go +++ b/middleware/envvar/envvar_test.go @@ -11,19 +11,15 @@ import ( "github.com/stretchr/testify/require" ) -func Test_EnvVarStructWithExportVarsExcludeVars(t *testing.T) { +func Test_EnvVarStructWithExportVars(t *testing.T) { t.Setenv("testKey", "testEnvValue") t.Setenv("anotherEnvKey", "anotherEnvVal") - t.Setenv("excludeKey", "excludeEnvValue") - vars := newEnvVar(Config{ - ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"}, - ExcludeVars: map[string]string{"excludeKey": ""}, + ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"}, }) require.Equal(t, "testEnvValue", vars.Vars["testKey"]) require.Equal(t, "testDefaultVal", vars.Vars["testDefaultKey"]) - require.Equal(t, "", vars.Vars["excludeKey"]) require.Equal(t, "", vars.Vars["anotherEnvKey"]) } @@ -92,8 +88,8 @@ func Test_EnvVarHandlerDefaultConfig(t *testing.T) { var envVars EnvVar require.NoError(t, json.Unmarshal(respBody, &envVars)) - val := envVars.Vars["testEnvKey"] - require.Equal(t, "testEnvVal", val) + _, exists := envVars.Vars["testEnvKey"] + require.False(t, exists) } func Test_EnvVarHandlerMethod(t *testing.T) { @@ -113,8 +109,8 @@ func Test_EnvVarHandlerSpecialValue(t *testing.T) { t.Setenv(testEnvKey, fakeBase64) app := fiber.New() - app.Use("/envvars", New()) app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}})) + app.Use("/envvars", New()) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil) require.NoError(t, err) @@ -126,8 +122,8 @@ func Test_EnvVarHandlerSpecialValue(t *testing.T) { var envVars EnvVar require.NoError(t, json.Unmarshal(respBody, &envVars)) - val := envVars.Vars[testEnvKey] - require.Equal(t, fakeBase64, val) + _, exists := envVars.Vars[testEnvKey] + require.False(t, exists) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", nil) require.NoError(t, err) @@ -139,6 +135,6 @@ func Test_EnvVarHandlerSpecialValue(t *testing.T) { var envVarsExport EnvVar require.NoError(t, json.Unmarshal(respBody, &envVarsExport)) - val = envVarsExport.Vars[testEnvKey] + val := envVarsExport.Vars[testEnvKey] require.Equal(t, fakeBase64, val) } diff --git a/middleware/helmet/helmet_test.go b/middleware/helmet/helmet_test.go index c9d4b1c4578..674b51a5724 100644 --- a/middleware/helmet/helmet_test.go +++ b/middleware/helmet/helmet_test.go @@ -6,6 +6,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" ) func Test_Default(t *testing.T) { @@ -199,3 +200,61 @@ func Test_PermissionsPolicy(t *testing.T) { require.NoError(t, err) require.Equal(t, "microphone=()", resp.Header.Get(fiber.HeaderPermissionsPolicy)) } + +func Test_HSTSHeaders(t *testing.T) { + hstsAge := 60 + app := fiber.New() + + app.Use(New(Config{HSTSMaxAge: hstsAge})) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + handler := app.Handler() + ctx := &fasthttp.RequestCtx{} + + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetProtocol("https") + + handler(ctx) + + require.Equal(t, "max-age=60; includeSubDomains", string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity))) + + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetProtocol("http") + + handler(ctx) + + require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity))) +} + +func Test_HSTSExcludeSubdomainsAndPreload(t *testing.T) { + hstsAge := 31536000 + app := fiber.New() + + app.Use(New(Config{ + HSTSMaxAge: hstsAge, + HSTSExcludeSubdomains: true, + HSTSPreloadEnabled: true, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + + handler := app.Handler() + ctx := &fasthttp.RequestCtx{} + + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.SetProtocol("https") + + handler(ctx) + + require.Equal(t, "max-age=31536000; preload", string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity))) +} diff --git a/middleware/idempotency/idempotency_test.go b/middleware/idempotency/idempotency_test.go index a4fa0bb0f6c..f124fc144f7 100644 --- a/middleware/idempotency/idempotency_test.go +++ b/middleware/idempotency/idempotency_test.go @@ -1,8 +1,10 @@ -package idempotency_test +package idempotency import ( "errors" + "fmt" "io" + "net/http" "net/http/httptest" "strconv" "sync" @@ -11,13 +13,14 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/middleware/idempotency" "github.com/valyala/fasthttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +const validKey = "00000000-0000-0000-0000-000000000000" + // go test -run Test_Idempotency func Test_Idempotency(t *testing.T) { t.Parallel() @@ -29,7 +32,7 @@ func Test_Idempotency(t *testing.T) { } isMethodSafe := fiber.IsMethodSafe(c.Method()) - isIdempotent := idempotency.IsFromCache(c) || idempotency.WasPutToCache(c) + isIdempotent := IsFromCache(c) || WasPutToCache(c) hasReqHeader := c.Get("X-Idempotency-Key") != "" if isMethodSafe { @@ -53,7 +56,7 @@ func Test_Idempotency(t *testing.T) { // Needs to be at least a second as the memory storage doesn't support shorter durations. const lifetime = 2 * time.Second - app.Use(idempotency.New(idempotency.Config{ + app.Use(New(Config{ Lifetime: lifetime, })) @@ -136,7 +139,7 @@ func Benchmark_Idempotency(b *testing.B) { // Needs to be at least a second as the memory storage doesn't support shorter durations. const lifetime = 1 * time.Second - app.Use(idempotency.New(idempotency.Config{ + app.Use(New(Config{ Lifetime: lifetime, })) @@ -169,3 +172,247 @@ func Benchmark_Idempotency(b *testing.B) { } }) } + +func Test_configDefault_defaults(t *testing.T) { + t.Parallel() + + cfg := configDefault() + require.NotNil(t, cfg.Lock) + require.NotNil(t, cfg.Storage) + require.Equal(t, ConfigDefault.Lifetime, cfg.Lifetime) + require.Equal(t, ConfigDefault.KeyHeader, cfg.KeyHeader) + require.Nil(t, cfg.KeepResponseHeaders) + + app := fiber.New() + + fctx := &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod(fiber.MethodGet) + ctx := app.AcquireCtx(fctx) + require.True(t, cfg.Next(ctx)) + app.ReleaseCtx(ctx) + + fctx = &fasthttp.RequestCtx{} + fctx.Request.Header.SetMethod(fiber.MethodPost) + ctx = app.AcquireCtx(fctx) + require.False(t, cfg.Next(ctx)) + app.ReleaseCtx(ctx) + + require.NoError(t, cfg.KeyHeaderValidate(validKey)) + require.Error(t, cfg.KeyHeaderValidate("short")) +} + +func Test_configDefault_override(t *testing.T) { + t.Parallel() + + l := &stubLock{} + s := &stubStorage{} + + cfg := configDefault(Config{ + Lifetime: 42 * time.Second, + KeyHeader: "Foo", + KeepResponseHeaders: []string{}, + Lock: l, + Storage: s, + }) + + require.Equal(t, 42*time.Second, cfg.Lifetime) + require.Equal(t, "Foo", cfg.KeyHeader) + require.Nil(t, cfg.KeepResponseHeaders) + require.Equal(t, l, cfg.Lock) + require.Equal(t, s, cfg.Storage) + require.NotNil(t, cfg.Next) + require.NotNil(t, cfg.KeyHeaderValidate) +} + +// helper to perform request +func do(app *fiber.App, req *http.Request) (*http.Response, string) { + resp, err := app.Test(req, fiber.TestConfig{Timeout: 5 * time.Second}) + if err != nil { + panic(err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + return resp, string(body) +} + +func Test_New_NextSkip(t *testing.T) { + t.Parallel() + app := fiber.New() + var count int + + app.Use(New(Config{Next: func(_ fiber.Ctx) bool { return true }})) + + app.Post("/", func(c fiber.Ctx) error { + count++ + return c.SendString(strconv.Itoa(count)) + }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + _, body1 := do(app, req) + + req2 := httptest.NewRequest(http.MethodPost, "/", nil) + req2.Header.Set(ConfigDefault.KeyHeader, validKey) + _, body2 := do(app, req2) + + require.Equal(t, "1", body1) + require.Equal(t, "2", body2) +} + +func Test_New_InvalidKey(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New()) + app.Post("/", func(_ fiber.Ctx) error { return nil }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, "bad") + resp, body := do(app, req) + + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Contains(t, body, "invalid length") +} + +func Test_New_StorageGetError(t *testing.T) { + t.Parallel() + app := fiber.New() + s := &stubStorage{getErr: errors.New("boom")} + app.Use(New(Config{Storage: s, Lock: &stubLock{}})) + app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Contains(t, body, "failed to write cached response at fastpath") +} + +func Test_New_UnmarshalError(t *testing.T) { + t.Parallel() + app := fiber.New() + s := &stubStorage{data: map[string][]byte{validKey: []byte("bad")}} + app.Use(New(Config{Storage: s, Lock: &stubLock{}})) + app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Contains(t, body, "failed to write cached response at fastpath") +} + +func Test_New_StoreRetrieve_FilterHeaders(t *testing.T) { + t.Parallel() + app := fiber.New() + s := &stubStorage{} + app.Use(New(Config{ + Storage: s, + Lock: &stubLock{}, + KeepResponseHeaders: []string{"Foo"}, + })) + + var count int + app.Post("/", func(c fiber.Ctx) error { + count++ + c.Set("Foo", "foo") + c.Set("Bar", "bar") + return c.SendString(fmt.Sprintf("resp%d", count)) + }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + require.Equal(t, "resp1", body) + require.Equal(t, "foo", resp.Header.Get("Foo")) + require.Equal(t, "bar", resp.Header.Get("Bar")) + + req2 := httptest.NewRequest(http.MethodPost, "/", nil) + req2.Header.Set(ConfigDefault.KeyHeader, validKey) + resp2, body2 := do(app, req2) + require.Equal(t, "resp1", body2) + require.Equal(t, "foo", resp2.Header.Get("Foo")) + require.Empty(t, resp2.Header.Get("Bar")) + require.Equal(t, 1, count) + require.Equal(t, 1, s.setCount) +} + +func Test_New_HandlerError(t *testing.T) { + t.Parallel() + app := fiber.New() + s := &stubStorage{} + app.Use(New(Config{Storage: s, Lock: &stubLock{}})) + app.Post("/", func(_ fiber.Ctx) error { return errors.New("boom") }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Equal(t, "boom", body) + require.Equal(t, 0, s.setCount) + + resp2, body2 := do(app, req) + require.Equal(t, fiber.StatusInternalServerError, resp2.StatusCode) + require.Equal(t, "boom", body2) + require.Equal(t, 0, s.setCount) +} + +func Test_New_LockError(t *testing.T) { + t.Parallel() + app := fiber.New() + l := &stubLock{lockErr: errors.New("fail")} + app.Use(New(Config{Lock: l, Storage: &stubStorage{}})) + app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Contains(t, body, "failed to lock") +} + +func Test_New_StorageSetError(t *testing.T) { + t.Parallel() + app := fiber.New() + s := &stubStorage{setErr: errors.New("nope")} + app.Use(New(Config{Storage: s, Lock: &stubLock{}})) + app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Contains(t, body, "failed to save response") +} + +func Test_New_UnlockError(t *testing.T) { + t.Parallel() + app := fiber.New() + l := &stubLock{unlockErr: errors.New("u")} + app.Use(New(Config{Lock: l, Storage: &stubStorage{}})) + app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Equal(t, "ok", body) +} + +func Test_New_SecondPassReadError(t *testing.T) { + t.Parallel() + app := fiber.New() + s := &stubStorage{} + l := &stubLock{afterLock: func() { s.getErr = errors.New("g") }} + app.Use(New(Config{Lock: l, Storage: s})) + app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") }) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(ConfigDefault.KeyHeader, validKey) + resp, body := do(app, req) + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Contains(t, body, "failed to write cached response while locked") +} diff --git a/middleware/idempotency/stub_test.go b/middleware/idempotency/stub_test.go new file mode 100644 index 00000000000..48aec611d38 --- /dev/null +++ b/middleware/idempotency/stub_test.go @@ -0,0 +1,64 @@ +package idempotency + +import ( + "time" +) + +// stubLock implements Locker for testing purposes. +type stubLock struct { + lockErr error + unlockErr error + afterLock func() +} + +func (s *stubLock) Lock(string) error { + if s.afterLock != nil { + s.afterLock() + } + return s.lockErr +} +func (s *stubLock) Unlock(string) error { return s.unlockErr } + +// stubStorage implements fiber.Storage for testing. +type stubStorage struct { + data map[string][]byte + getErr error + setErr error + setCount int +} + +func (s *stubStorage) Get(key string) ([]byte, error) { + if s.getErr != nil { + return nil, s.getErr + } + if s.data == nil { + return nil, nil + } + return s.data[key], nil +} + +func (s *stubStorage) Set(key string, val []byte, _ time.Duration) error { + if s.setErr != nil { + return s.setErr + } + if s.data == nil { + s.data = make(map[string][]byte) + } + s.data[key] = val + s.setCount++ + return nil +} + +func (s *stubStorage) Delete(key string) error { + if s.data != nil { + delete(s.data, key) + } + return nil +} + +func (s *stubStorage) Reset() error { + s.data = make(map[string][]byte) + return nil +} + +func (*stubStorage) Close() error { return nil } diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index c14bc19efe2..fe149ec5979 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -232,6 +232,21 @@ func (m *Middleware) Delete(key any) { m.Session.Delete(key) } +// Keys returns all keys in the current session. +// +// Returns: +// - []any: A slice of all keys in the session. +// +// Usage: +// +// keys := m.Keys() +func (m *Middleware) Keys() []any { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.Session.Keys() +} + // Destroy destroys the session. // // Returns: diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index 9cbb8cd53bf..6c9d01f69b7 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -1,6 +1,8 @@ package session import ( + "fmt" + "sort" "strings" "sync" "testing" @@ -86,6 +88,41 @@ func Test_Session_Middleware(t *testing.T) { return c.SendStatus(fiber.StatusInternalServerError) }) + app.Post("/keys", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + // get a value from the body + value := c.FormValue("keys") + for _, rawKey := range strings.Split(value, ",") { + key := strings.TrimSpace(rawKey) + if key == "" { + continue + } + // Set each key in the session + sess.Set(key, "value_"+key) + } + return c.SendStatus(fiber.StatusOK) + }) + + app.Get("/keys", func(c fiber.Ctx) error { + sess := FromContext(c) + if sess == nil { + return c.SendStatus(fiber.StatusInternalServerError) + } + keys := sess.Keys() + if len(keys) == 0 { + return c.SendStatus(fiber.StatusNotFound) + } + // Keys may be of any type, so convert to string for display + strKeys := []string{} + for _, key := range keys { + strKeys = append(strKeys, fmt.Sprintf("%v", key)) + } + return c.SendString("keys=" + strings.Join(strKeys, ",")) + }) + // Test GET, SET, DELETE, RESET, DESTROY by sending requests to the respective routes ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) @@ -207,6 +244,34 @@ func Test_Session_Middleware(t *testing.T) { require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token") newToken = newTokenParts[1] require.NotEqual(t, token, newToken) + + token = newToken + + // Test POST /keys to set multiple keys + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodPost) + ctx.Request.SetRequestURI("/keys") + ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type + ctx.Request.SetBodyString("keys=key1,key2") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + + // Test GET /keys to check if the session has the keys + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.SetRequestURI("/keys") + ctx.Request.Header.SetCookie("session_id", token) + h(ctx) + require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) + body := string(ctx.Response.Body()) + require.True(t, strings.HasPrefix(body, "keys=")) + parts = strings.Split(strings.TrimPrefix(body, "keys="), ",") + require.Len(t, parts, 2, "Expected two keys in the session") + sort.Strings(parts) + require.Equal(t, []string{"key1", "key2"}, parts) } func Test_Session_NewWithStore(t *testing.T) { diff --git a/middleware/session/session.go b/middleware/session/session.go index ffb5c527228..094238ddae7 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -325,7 +325,7 @@ func (s *Session) saveSession() error { // Keys retrieves all keys in the current session. // // Returns: -// - []string: A slice of all keys in the session. +// - []any: A slice of all keys in the session. // // Usage: // diff --git a/middleware/static/static_test.go b/middleware/static/static_test.go index 28ef285ef28..998f21a5bac 100644 --- a/middleware/static/static_test.go +++ b/middleware/static/static_test.go @@ -740,7 +740,7 @@ func Test_Static_Compress(t *testing.T) { for _, algo := range algorithms { t.Run(algo+"_compression", func(t *testing.T) { t.Parallel() - // request non-compressable file (less than 200 bytes), Content Lengh will remain the same + // request non-compressible file (less than 200 bytes), Content Length will remain the same req := httptest.NewRequest(fiber.MethodGet, "/css/style.css", nil) req.Header.Set("Accept-Encoding", algo) resp, err := app.Test(req, testConfig) @@ -750,7 +750,7 @@ func Test_Static_Compress(t *testing.T) { require.Equal(t, "", resp.Header.Get(fiber.HeaderContentEncoding)) require.Equal(t, "46", resp.Header.Get(fiber.HeaderContentLength)) - // request compressable file, ContentLenght will change + // request compressible file, ContentLength will change req = httptest.NewRequest(fiber.MethodGet, "/index.html", nil) req.Header.Set("Accept-Encoding", algo) resp, err = app.Test(req, testConfig) @@ -772,7 +772,7 @@ func Test_Static_Compress_WithoutEncoding(t *testing.T) { CacheDuration: 1 * time.Second, })) - // request compressable file without encoding + // request compressible file without encoding req := httptest.NewRequest(fiber.MethodGet, "/index.html", nil) resp, err := app.Test(req, testConfig) @@ -781,7 +781,7 @@ func Test_Static_Compress_WithoutEncoding(t *testing.T) { require.Equal(t, "", resp.Header.Get(fiber.HeaderContentEncoding)) require.Equal(t, "299", resp.Header.Get(fiber.HeaderContentLength)) - // request compressable file with different encodings + // request compressible file with different encodings algorithms := []string{"zstd", "gzip", "br"} fileSuffixes := map[string]string{ "gzip": ".fiber.gz", @@ -827,7 +827,7 @@ func Test_Static_Compress_WithFileSuffixes(t *testing.T) { CacheDuration: 1 * time.Second, })) - // request compressable file with different encodings + // request compressible file with different encodings algorithms := []string{"zstd", "gzip", "br"} for _, algo := range algorithms { diff --git a/path.go b/path.go index 91b185e6828..2c98027093a 100644 --- a/path.go +++ b/path.go @@ -260,7 +260,7 @@ func addParameterMetaInfo(segs []*routeSegment) []*routeSegment { segs[i].PartCount += strings.Count(segs[j].Const, segs[i].ComparePart) } } - // check if the end of the segment is a optional slash and then if the segement is optional or the last one + // check if the end of the segment is an optional slash and then if the segment is optional or the last one } else if segs[i].Const[len(segs[i].Const)-1] == slashDelimiter && (segs[i].IsLast || (segLen > i+1 && segs[i+1].IsOptional)) { segs[i].HasOptionalSlash = true } diff --git a/path_test.go b/path_test.go index 1a3b229d291..fc112d369c4 100644 --- a/path_test.go +++ b/path_test.go @@ -58,7 +58,7 @@ func Test_Path_parseRoute(t *testing.T) { params: []string{"name"}, }, rp) - // heavy test with escaped charaters + // heavy test with escaped characters rp = parseRoute("/v1/some/resource/name\\\\:customVerb?\\?/:param/*") require.Equal(t, routeParser{ segs: []*routeSegment{ diff --git a/prefork.go b/prefork.go index 745ed30627c..b6ed594141d 100644 --- a/prefork.go +++ b/prefork.go @@ -77,12 +77,12 @@ func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg ListenConfig) er } // create variables maxProcs := runtime.GOMAXPROCS(0) - childs := make(map[int]*exec.Cmd) + children := make(map[int]*exec.Cmd) channel := make(chan child, maxProcs) // kill child procs when master exits defer func() { - for _, proc := range childs { + for _, proc := range children { if err := proc.Process.Kill(); err != nil { if !errors.Is(err, os.ErrProcessDone) { log.Errorf("prefork: failed to kill child: %v", err) @@ -117,7 +117,7 @@ func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg ListenConfig) er // store child process pid := cmd.Process.Pid - childs[pid] = cmd + children[pid] = cmd pids = append(pids, strconv.Itoa(pid)) // execute fork hook