diff --git a/AGENTS.md b/AGENTS.md index 580bf1fa527..4b23d8d8704 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -2,16 +2,16 @@ ## Agent Instructions -> **Note for the AI agent:** Before running the startup script, check for a `vendor` directory. If it exists, skip fetching modules; otherwise proceed. This vendor-check logic is for the agent’s decision-making and is not part of the user-facing script. +> **Note for the AI agent:** Before running the startup script, verify that all dependencies and CLI tools are already available. If a `vendor/` directory exists **or** `go list -m -mod=readonly all` succeeds (indicating modules are cached), you may skip the entire startup script—including all `go mod` commands and `go install` steps. Otherwise run the script as listed below. --- -## Startup script +## Startup script (reference only – do not run) - Fetch dependencies: ```bash - go get ./... + go mod tidy && go mod download && go mod vendor ``` - Install CLI tools referenced in Makefile: @@ -25,3 +25,20 @@ go install github.com/dkorunic/betteralign/cmd/betteralign@latest # struct alignment go mod tidy # clean up go.mod & go.sum ``` + +## Makefile commands + +Use `make help` to list all available commands. Common targets include: + +- **audit**: run `go mod verify`, `go vet`, and `govulncheck` for quality checks. +- **benchmark**: run benchmarks with `go test`. +- **coverage**: generate a coverage report. +- **format**: apply formatting using `gofumpt`. +- **lint**: execute `golangci-lint`. +- **test**: run the test suite with `gotestsum`. +- **longtest**: run the test suite 15 times with shuffling enabled. +- **tidy**: clean and tidy dependencies. +- **betteralign**: optimize struct field alignment. +- **generate**: run `go generate` after installing msgp and ifacemaker. + +These targets can be invoked via `make ` as needed during development and testing. diff --git a/app_test.go b/app_test.go index 1131ef52e61..6e154858ed4 100644 --- a/app_test.go +++ b/app_test.go @@ -529,7 +529,7 @@ func Test_App_Use_CaseSensitive(t *testing.T) { require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") - // right letters in the requrested route -> 200 + // right letters in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/abc", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") @@ -565,7 +565,7 @@ func Test_App_Not_Use_StrictRouting(t *testing.T) { require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") - // right path in the requrested route -> 200 + // right path in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/abc", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") @@ -575,7 +575,7 @@ func Test_App_Not_Use_StrictRouting(t *testing.T) { require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") - // right path with group in the requrested route -> 200 + // right path with group in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo/", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") @@ -645,7 +645,7 @@ func Test_App_Use_StrictRouting(t *testing.T) { require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") - // right path in the requrested route -> 200 + // right path in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/abc", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") @@ -655,7 +655,7 @@ func Test_App_Use_StrictRouting(t *testing.T) { require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") - // right path with group in the requrested route -> 200 + // right path with group in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo/", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") diff --git a/ctx.go b/ctx.go index 5849501c260..31361ff1122 100644 --- a/ctx.go +++ b/ctx.go @@ -39,12 +39,14 @@ const ( maxDetectionPaths = 3 ) +var ( + _ io.Writer = (*DefaultCtx)(nil) // Compile-time check + _ context.Context = (*DefaultCtx)(nil) // Compile-time check +) + // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. -type contextKey int - -// userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx -const userContextKey contextKey = 0 // __local_user_context__ +type contextKey int //nolint:unused // need for future (nolintlint) // DefaultCtx is the default implementation of the Ctx interface // generation tool `go install github.com/vburenin/ifacemaker@975a95966976eeb2d4365a7fb236e274c54da64c` @@ -391,23 +393,6 @@ func (c *DefaultCtx) RequestCtx() *fasthttp.RequestCtx { return c.fasthttp } -// Context returns a context implementation that was set by -// user earlier or returns a non-nil, empty context,if it was not set earlier. -func (c *DefaultCtx) Context() context.Context { - ctx, ok := c.fasthttp.UserValue(userContextKey).(context.Context) - if !ok { - ctx = context.Background() - c.SetContext(ctx) - } - - return ctx -} - -// SetContext sets a context implementation by user. -func (c *DefaultCtx) SetContext(ctx context.Context) { - c.fasthttp.SetUserValue(userContextKey, ctx) -} - // Cookie sets a cookie by passing a cookie struct. func (c *DefaultCtx) Cookie(cookie *Cookie) { fcookie := fasthttp.AcquireCookie() @@ -444,6 +429,28 @@ func (c *DefaultCtx) Cookie(cookie *Cookie) { fasthttp.ReleaseCookie(fcookie) } +// Deadline returns the time when work done on behalf of this context +// should be canceled. Deadline returns ok==false when no deadline is +// set. Successive calls to Deadline return the same results. +// +// Due to current limitations in how fasthttp works, Deadline operates as a nop. +// See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 +func (*DefaultCtx) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done returns a channel that's closed when work done on behalf of this +// context should be canceled. Done may return nil if this context can +// never be canceled. Successive calls to Done return the same value. +// The close of the Done channel may happen asynchronously, +// after the cancel function returns. +// +// Due to current limitations in how fasthttp works, Done operates as a nop. +// See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 +func (*DefaultCtx) Done() <-chan struct{} { + return nil +} + // Cookies are used for getting a cookie value by key. // Defaults to the empty string "" if the cookie doesn't exist. // If a default value is given, it will return that value if the cookie doesn't exist. @@ -468,6 +475,18 @@ func (c *DefaultCtx) Download(file string, filename ...string) error { return c.SendFile(file) } +// If Done is not yet closed, Err returns nil. +// If Done is closed, Err returns a non-nil error explaining why: +// context.DeadlineExceeded if the context's deadline passed, +// or context.Canceled if the context was canceled for some other reason. +// After Err returns a non-nil error, successive calls to Err return the same error. +// +// Due to current limitations in how fasthttp works, Err operates as a nop. +// See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 +func (*DefaultCtx) Err() error { + return nil +} + // Request return the *fasthttp.Request object // This allows you to use all fasthttp request methods // https://godoc.org/github.com/valyala/fasthttp#Request @@ -644,9 +663,14 @@ func (c *DefaultCtx) Get(key string, defaultValue ...string) string { // GetReqHeader returns the HTTP request header specified by filed. // This function is generic and can handle different headers type values. +// If the generic type cannot be matched to a supported type, the function +// returns the default value (if provided) or the zero value of type V. func GetReqHeader[V GenericType](c Ctx, key string, defaultValue ...V) V { - var v V - return genericParseType[V](c.App().getString(c.Request().Header.Peek(key)), v, defaultValue...) + v, err := genericParseType[V](c.App().getString(c.Request().Header.Peek(key))) + if err != nil && len(defaultValue) > 0 { + return defaultValue[0] + } + return v } // GetRespHeader returns the HTTP response header specified by field. @@ -1103,6 +1127,8 @@ func (c *DefaultCtx) Params(key string, defaultValue ...string) string { // Params is used to get the route parameters. // This function is generic and can handle different route parameters type values. +// If the generic type cannot be matched to a supported type, the function +// returns the default value (if provided) or the zero value of type V. // // Example: // @@ -1115,8 +1141,11 @@ func (c *DefaultCtx) Params(key string, defaultValue ...string) string { // http://example.com/id/:number -> http://example.com/id/john // Params[int](c, "number", 0) -> returns 0 because can't parse 'john' as integer. func Params[V GenericType](c Ctx, key string, defaultValue ...V) V { - var v V - return genericParseType(c.Params(key), v, defaultValue...) + v, err := genericParseType[V](c.Params(key)) + if err != nil && len(defaultValue) > 0 { + return defaultValue[0] + } + return v } // Path returns the path part of the request URL. @@ -1238,10 +1267,12 @@ func (c *DefaultCtx) Queries() map[string]string { // age := Query[int](c, "age") // Returns 8 // unknown := Query[string](c, "unknown", "default") // Returns "default" since the query parameter "unknown" is not found func Query[V GenericType](c Ctx, key string, defaultValue ...V) V { - var v V q := c.App().getString(c.RequestCtx().QueryArgs().Peek(key)) - - return genericParseType[V](q, v, defaultValue...) + v, err := genericParseType[V](q) + if err != nil && len(defaultValue) > 0 { + return defaultValue[0] + } + return v } // Range returns a struct containing the type and a slice of ranges. @@ -1804,6 +1835,12 @@ func (c *DefaultCtx) Vary(fields ...string) { c.Append(HeaderVary, fields...) } +// Value makes it possible to retrieve values (Locals) under keys scoped to the request +// and therefore available to all following routes that match the request. +func (c *DefaultCtx) Value(key any) any { + return c.fasthttp.UserValue(key) +} + // Write appends p into response body. func (c *DefaultCtx) Write(p []byte) (int, error) { c.fasthttp.Response.AppendBody(p) diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index 5d824b1d41c..c14981d2061 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -4,10 +4,10 @@ package fiber import ( "bufio" - "context" "crypto/tls" "io" "mime/multipart" + "time" "github.com/valyala/fasthttp" ) @@ -49,11 +49,6 @@ type Ctx interface { // RequestCtx returns *fasthttp.RequestCtx that carries a deadline // a cancellation signal, and other values across API boundaries. RequestCtx() *fasthttp.RequestCtx - // Context returns a context implementation that was set by - // user earlier or returns a non-nil, empty context,if it was not set earlier. - Context() context.Context - // SetContext sets a context implementation by user. - SetContext(ctx context.Context) // Cookie sets a cookie by passing a cookie struct. Cookie(cookie *Cookie) // Cookies are used for getting a cookie value by key. @@ -62,11 +57,27 @@ type Ctx interface { // The returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. Cookies(key string, defaultValue ...string) string + // Deadline returns the time when work done on behalf of this context + // should be canceled. Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // The close of the Done channel may happen asynchronously, + // after the cancel function returns. + Done() <-chan struct{} // Download transfers the file from path as an attachment. // Typically, browsers will prompt the user for download. // By default, the Content-Disposition header filename= parameter is the filepath (this typically appears in the browser dialog). // Override this default with the filename parameter. Download(file string, filename ...string) error + // If Done is not yet closed, Err returns nil. + // If Done is closed, Err returns a non-nil error explaining why: + // DeadlineExceeded if the context's deadline passed, + // or Canceled if the context was canceled for some other reason. + // After Err returns a non-nil error, successive calls to Err return the same error. + Err() error // Request return the *fasthttp.Request object // This allows you to use all fasthttp request methods // https://godoc.org/github.com/valyala/fasthttp#Request @@ -317,6 +328,9 @@ type Ctx interface { // Vary adds the given header field to the Vary response header. // This will append the header, if not already listed, otherwise leaves it listed in the current location. Vary(fields ...string) + // Value makes it possible to retrieve values (Locals) under keys scoped to the request + // and therefore available to all following routes that match the request. + Value(key any) any // Write appends p into response body. Write(p []byte) (int, error) // Writef appends f & a into response body writer. diff --git a/ctx_test.go b/ctx_test.go index 5040f4f8fc4..08113efdd1f 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -9,7 +9,6 @@ import ( "bytes" "compress/gzip" "compress/zlib" - "context" "crypto/tls" "embed" "encoding/hex" @@ -17,7 +16,6 @@ import ( "errors" "fmt" "io" - "math" "mime/multipart" "net" "net/http/httptest" @@ -882,76 +880,6 @@ func Test_Ctx_RequestCtx(t *testing.T) { require.Equal(t, "*fasthttp.RequestCtx", fmt.Sprintf("%T", c.RequestCtx())) } -// go test -run Test_Ctx_Context -func Test_Ctx_Context(t *testing.T) { - t.Parallel() - app := New() - c := app.AcquireCtx(&fasthttp.RequestCtx{}) - - t.Run("Nil_Context", func(t *testing.T) { - t.Parallel() - ctx := c.Context() - require.Equal(t, ctx, context.Background()) - }) - t.Run("ValueContext", func(t *testing.T) { - t.Parallel() - testKey := struct{}{} - testValue := "Test Value" - ctx := context.WithValue(context.Background(), testKey, testValue) //nolint:staticcheck // not needed for tests - require.Equal(t, testValue, ctx.Value(testKey)) - }) -} - -// go test -run Test_Ctx_SetContext -func Test_Ctx_SetContext(t *testing.T) { - t.Parallel() - app := New() - c := app.AcquireCtx(&fasthttp.RequestCtx{}) - - testKey := struct{}{} - testValue := "Test Value" - ctx := context.WithValue(context.Background(), testKey, testValue) //nolint:staticcheck // not needed for tests - c.SetContext(ctx) - require.Equal(t, testValue, c.Context().Value(testKey)) -} - -// go test -run Test_Ctx_Context_Multiple_Requests -func Test_Ctx_Context_Multiple_Requests(t *testing.T) { - t.Parallel() - testKey := struct{}{} - testValue := "foobar-value" - - app := New() - app.Get("/", func(c Ctx) error { - ctx := c.Context() - - if ctx.Value(testKey) != nil { - return c.SendStatus(StatusInternalServerError) - } - - input := utils.CopyString(Query(c, "input", "NO_VALUE")) - ctx = context.WithValue(ctx, testKey, fmt.Sprintf("%s_%s", testValue, input)) //nolint:staticcheck // not needed for tests - c.SetContext(ctx) - - return c.Status(StatusOK).SendString(fmt.Sprintf("resp_%s_returned", input)) - }) - - // Consecutive Requests - for i := 1; i <= 10; i++ { - t.Run(fmt.Sprintf("request_%d", i), func(t *testing.T) { - t.Parallel() - resp, err := app.Test(httptest.NewRequest(MethodGet, fmt.Sprintf("/?input=%d", i), nil)) - - require.NoError(t, err, "Unexpected error from response") - require.Equal(t, StatusOK, resp.StatusCode, "context.Context returned from c.Context() is reused") - - b, err := io.ReadAll(resp.Body) - require.NoError(t, err, "Unexpected error from reading response body") - require.Equal(t, fmt.Sprintf("resp_%d_returned", i), string(b), "response text incorrect") - }) - } -} - // go test -run Test_Ctx_Cookie func Test_Ctx_Cookie(t *testing.T) { t.Parallel() @@ -1472,10 +1400,10 @@ func Test_Ctx_Binders(t *testing.T) { type TestStruct struct { Name string - NameWithDefault string `json:"name2" xml:"Name2" form:"name2" cookie:"name2" query:"name2" params:"name2" header:"Name2"` + NameWithDefault string `json:"name2" xml:"Name2" form:"name2" cookie:"name2" query:"name2" uri:"name2" header:"Name2"` TestEmbeddedStruct Class int - ClassWithDefault int `json:"class2" xml:"Class2" form:"class2" cookie:"class2" query:"class2" params:"class2" header:"Class2"` + ClassWithDefault int `json:"class2" xml:"Class2" form:"class2" cookie:"class2" query:"class2" uri:"class2" header:"Class2"` } withValues := func(t *testing.T, actionFn func(c Ctx, testStruct *TestStruct) error) { @@ -1541,16 +1469,26 @@ func Test_Ctx_Binders(t *testing.T) { return c.Bind().Query(testStruct) }) }) + t.Run("URI", func(t *testing.T) { - t.Skip("URI is not ready for v3") - //nolint:gocritic // TODO: uncomment - // t.Parallel() - // withValues(t, func(c Ctx, testStruct *TestStruct) error { - // c.Route().Params = []string{"name", "name2", "class", "class2"} - // c.Params().value = [30]string{"foo", "bar", "111", "222"} - // return c.Bind().URI(testStruct) - // }) + t.Parallel() + + c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed + defer app.ReleaseCtx(c) + + c.route = &Route{Params: []string{"name", "name2", "class", "class2"}} + c.values = [maxParams]string{"foo", "bar", "111", "222"} + + testStruct := new(TestStruct) + + require.NoError(t, c.Bind().URI(testStruct)) + require.Equal(t, "foo", testStruct.Name) + require.Equal(t, 111, testStruct.Class) + require.Equal(t, "bar", testStruct.NameWithDefault) + require.Equal(t, 222, testStruct.ClassWithDefault) + require.Nil(t, testStruct.TestEmbeddedStruct.Names) }) + t.Run("ReqHeader", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { @@ -2258,6 +2196,73 @@ func Test_Ctx_Locals(t *testing.T) { require.Equal(t, StatusOK, resp.StatusCode, "Status code") } +// go test -run Test_Ctx_Deadline +func Test_Ctx_Deadline(t *testing.T) { + t.Parallel() + app := New() + app.Use(func(c Ctx) error { + return c.Next() + }) + app.Get("/test", func(c Ctx) error { + deadline, ok := c.Deadline() + require.Equal(t, time.Time{}, deadline) + require.False(t, ok) + return nil + }) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") +} + +// go test -run Test_Ctx_Done +func Test_Ctx_Done(t *testing.T) { + t.Parallel() + app := New() + app.Use(func(c Ctx) error { + return c.Next() + }) + app.Get("/test", func(c Ctx) error { + require.Equal(t, (<-chan struct{})(nil), c.Done()) + return nil + }) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") +} + +// go test -run Test_Ctx_Err +func Test_Ctx_Err(t *testing.T) { + t.Parallel() + app := New() + app.Use(func(c Ctx) error { + return c.Next() + }) + app.Get("/test", func(c Ctx) error { + require.NoError(t, c.Err()) + return nil + }) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") +} + +// go test -run Test_Ctx_Value +func Test_Ctx_Value(t *testing.T) { + t.Parallel() + app := New() + app.Use(func(c Ctx) error { + c.Locals("john", "doe") + return c.Next() + }) + app.Get("/test", func(c Ctx) error { + require.Equal(t, "doe", c.Value("john")) + return nil + }) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, StatusOK, resp.StatusCode, "Status code") +} + // go test -run Test_Ctx_Locals_Generic func Test_Ctx_Locals_Generic(t *testing.T) { t.Parallel() @@ -2494,6 +2499,7 @@ func Test_Ctx_Params(t *testing.T) { }) app.Get("/test4/:optional?", func(c Ctx) error { require.Equal(t, "", c.Params("optional")) + require.Equal(t, "default", Params(c, "optional", "default")) return nil }) app.Get("/test5/:id/:Id", func(c Ctx) error { @@ -5215,6 +5221,18 @@ func Test_Ctx_GetReqHeaders(t *testing.T) { }, c.GetReqHeaders()) } +func Test_Ctx_Set_SanitizeHeaderValue(t *testing.T) { + t.Parallel() + + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + + c.Set("X-Test", "foo\r\nbar: bad") + + headerVal := string(c.Response().Header.Peek("X-Test")) + require.Equal(t, "foo bar: bad", headerVal) +} + func Benchmark_Ctx_GetReqHeaders(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) @@ -5238,765 +5256,121 @@ func Benchmark_Ctx_GetReqHeaders(b *testing.B) { }, headers) } -// go test -run Test_GenericParseTypeInts -func Test_GenericParseTypeInts(t *testing.T) { +// go test -run Test_Ctx_Drop -v +func Test_Ctx_Drop(t *testing.T) { t.Parallel() - type genericTypes[v GenericType] struct { - value v - str string - } - - ints := []genericTypes[int]{ - { - value: 0, - str: "0", - }, - { - value: 1, - str: "1", - }, - { - value: 2, - str: "2", - }, - { - value: 3, - str: "3", - }, - { - value: 4, - str: "4", - }, - { - value: 2147483647, - str: "2147483647", - }, - { - value: -2147483648, - str: "-2147483648", - }, - { - value: -1, - str: "-1", - }, - } - for _, test := range ints { - var v int - tt := test - t.Run("test_genericParseTypeInts", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[int](tt.str, v)) - }) - } -} + app := New() -// go test -run Test_GenericParseTypeInt8s -func Test_GenericParseTypeInt8s(t *testing.T) { - t.Parallel() + // Handler that calls Drop + app.Get("/block-me", func(c Ctx) error { + return c.Drop() + }) - type genericTypes[v GenericType] struct { - value v - str string - } + // Additional handler that just calls return + app.Get("/no-response", func(_ Ctx) error { + return nil + }) - int8s := []genericTypes[int8]{ - { - value: int8(0), - str: "0", - }, - { - value: int8(1), - str: "1", - }, - { - value: int8(2), - str: "2", - }, - { - value: int8(3), - str: "3", - }, - { - value: int8(4), - str: "4", - }, - { - value: int8(math.MaxInt8), - str: strconv.Itoa(math.MaxInt8), - }, - { - value: int8(math.MinInt8), - str: strconv.Itoa(math.MinInt8), - }, - } + // Test the Drop method + resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", nil)) + require.ErrorIs(t, err, ErrTestGotEmptyResponse) + require.Nil(t, resp) - for _, test := range int8s { - var v int8 - tt := test - t.Run("test_genericParseTypeInt8s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[int8](tt.str, v)) - }) - } + // Test the no-response handler + resp, err = app.Test(httptest.NewRequest(MethodGet, "/no-response", nil)) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, StatusOK, resp.StatusCode) + require.Equal(t, "0", resp.Header.Get("Content-Length")) } -// go test -run Test_GenericParseTypeInt16s -func Test_GenericParseTypeInt16s(t *testing.T) { +// go test -run Test_Ctx_DropWithMiddleware -v +func Test_Ctx_DropWithMiddleware(t *testing.T) { t.Parallel() - type genericTypes[v GenericType] struct { - value v - str string - } - int16s := []genericTypes[int16]{ - { - value: int16(0), - str: "0", - }, - { - value: int16(1), - str: "1", - }, - { - value: int16(2), - str: "2", - }, - { - value: int16(3), - str: "3", - }, - { - value: int16(4), - str: "4", - }, - { - value: int16(math.MaxInt16), - str: strconv.Itoa(math.MaxInt16), - }, - { - value: int16(math.MinInt16), - str: strconv.Itoa(math.MinInt16), - }, - } + app := New() - for _, test := range int16s { - var v int16 - tt := test - t.Run("test_genericParseTypeInt16s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[int16](tt.str, v)) - }) - } + // Middleware that calls Drop + app.Use(func(c Ctx) error { + err := c.Next() + c.Set("X-Test", "test") + return err + }) + + // Handler that calls Drop + app.Get("/block-me", func(c Ctx) error { + return c.Drop() + }) + + // Test the Drop method + resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", nil)) + require.ErrorIs(t, err, ErrTestGotEmptyResponse) + require.Nil(t, resp) } -// go test -run Test_GenericParseTypeInt32s -func Test_GenericParseTypeInt32s(t *testing.T) { - t.Parallel() - type genericTypes[v GenericType] struct { - value v - str string - } +// go test -run Test_Ctx_End +func Test_Ctx_End(t *testing.T) { + app := New() - int32s := []genericTypes[int32]{ - { - value: int32(0), - str: "0", - }, - { - value: int32(1), - str: "1", - }, - { - value: int32(2), - str: "2", - }, - { - value: int32(3), - str: "3", - }, - { - value: int32(4), - str: "4", - }, - { - value: int32(math.MaxInt32), - str: strconv.Itoa(math.MaxInt32), - }, - { - value: int32(math.MinInt32), - str: strconv.Itoa(math.MinInt32), - }, - } + app.Get("/", func(c Ctx) error { + c.SendString("Hello, World!") //nolint:errcheck // unnecessary to check error + return c.End() + }) - for _, test := range int32s { - var v int32 - tt := test - t.Run("test_genericParseTypeInt32s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[int32](tt.str, v)) - }) - } + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "io.ReadAll(resp.Body)") + require.Equal(t, "Hello, World!", string(body)) } -// go test -run Test_GenericParseTypeInt64s -func Test_GenericParseTypeInt64s(t *testing.T) { - t.Parallel() - type genericTypes[v GenericType] struct { - value v - str string - } +// go test -run Test_Ctx_End_after_timeout +func Test_Ctx_End_after_timeout(t *testing.T) { + app := New() - int64s := []genericTypes[int64]{ - { - value: int64(0), - str: "0", - }, - { - value: int64(1), - str: "1", - }, - { - value: int64(2), - str: "2", - }, - { - value: int64(3), - str: "3", - }, - { - value: int64(4), - str: "4", - }, - { - value: int64(math.MaxInt64), - str: strconv.Itoa(math.MaxInt64), - }, - { - value: int64(math.MinInt64), - str: strconv.Itoa(math.MinInt64), - }, - } - - for _, test := range int64s { - var v int64 - tt := test - t.Run("test_genericParseTypeInt64s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[int64](tt.str, v)) - }) - } -} - -// go test -run Test_GenericParseTypeUints -func Test_GenericParseTypeUints(t *testing.T) { - t.Parallel() - type genericTypes[v GenericType] struct { - value v - str string - } - - uints := []genericTypes[uint]{ - { - value: uint(0), - str: "0", - }, - { - value: uint(1), - str: "1", - }, - { - value: uint(2), - str: "2", - }, - { - value: uint(3), - str: "3", - }, - { - value: uint(4), - str: "4", - }, - { - value: ^uint(0), - str: strconv.FormatUint(uint64(^uint(0)), 10), - }, - } - - for _, test := range uints { - var v uint - tt := test - t.Run("test_genericParseTypeUints", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[uint](tt.str, v)) - }) - } -} - -// go test -run Test_GenericParseTypeUints -func Test_GenericParseTypeUint8s(t *testing.T) { - t.Parallel() - type genericTypes[v GenericType] struct { - value v - str string - } - - uint8s := []genericTypes[uint8]{ - { - value: uint8(0), - str: "0", - }, - { - value: uint8(1), - str: "1", - }, - { - value: uint8(2), - str: "2", - }, - { - value: uint8(3), - str: "3", - }, - { - value: uint8(4), - str: "4", - }, - { - value: uint8(math.MaxUint8), - str: strconv.Itoa(math.MaxUint8), - }, - } - - for _, test := range uint8s { - var v uint8 - tt := test - t.Run("test_genericParseTypeUint8s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[uint8](tt.str, v)) - }) - } -} - -// go test -run Test_GenericParseTypeUint16s -func Test_GenericParseTypeUint16s(t *testing.T) { - t.Parallel() - - type genericTypes[v GenericType] struct { - value v - str string - } - - uint16s := []genericTypes[uint16]{ - { - value: uint16(0), - str: "0", - }, - { - value: uint16(1), - str: "1", - }, - { - value: uint16(2), - str: "2", - }, - { - value: uint16(3), - str: "3", - }, - { - value: uint16(4), - str: "4", - }, - { - value: uint16(math.MaxUint16), - str: strconv.Itoa(math.MaxUint16), - }, - } - - for _, test := range uint16s { - var v uint16 - tt := test - t.Run("test_genericParseTypeUint16s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[uint16](tt.str, v)) - }) - } -} - -// go test -run Test_GenericParseTypeUint32s -func Test_GenericParseTypeUint32s(t *testing.T) { - t.Parallel() - - type genericTypes[v GenericType] struct { - value v - str string - } - - uint32s := []genericTypes[uint32]{ - { - value: uint32(0), - str: "0", - }, - { - value: uint32(1), - str: "1", - }, - { - value: uint32(2), - str: "2", - }, - { - value: uint32(3), - str: "3", - }, - { - value: uint32(4), - str: "4", - }, - { - value: uint32(math.MaxUint32), - str: strconv.Itoa(math.MaxUint32), - }, - } - - for _, test := range uint32s { - var v uint32 - tt := test - t.Run("test_genericParseTypeUint32s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[uint32](tt.str, v)) - }) - } -} - -// go test -run Test_GenericParseTypeUint64s -func Test_GenericParseTypeUint64s(t *testing.T) { - t.Parallel() - type genericTypes[v GenericType] struct { - value v - str string - } - - uint64s := []genericTypes[uint64]{ - { - value: uint64(0), - str: "0", - }, - { - value: uint64(1), - str: "1", - }, - { - value: uint64(2), - str: "2", - }, - { - value: uint64(3), - str: "3", - }, - { - value: uint64(4), - str: "4", - }, - } - - for _, test := range uint64s { - var v uint64 - tt := test - t.Run("test_genericParseTypeUint64s", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v)) - require.Equal(t, tt.value, genericParseType[uint64](tt.str, v)) - }) - } -} - -// go test -run Test_GenericParseTypeFloat32s -func Test_GenericParseTypeFloat32s(t *testing.T) { - t.Parallel() - - type genericTypes[v GenericType] struct { - value v - str string - } - - float32s := []genericTypes[float32]{ - { - value: float32(3.1415), - str: "3.1415", - }, - { - value: float32(1.234), - str: "1.234", - }, - { - value: float32(2), - str: "2", - }, - { - value: float32(3), - str: "3", - }, - } - - for _, test := range float32s { - var v float32 - tt := test - t.Run("test_genericParseTypeFloat32s", func(t *testing.T) { - t.Parallel() - require.InEpsilon(t, tt.value, genericParseType(tt.str, v), epsilon) - require.InEpsilon(t, tt.value, genericParseType[float32](tt.str, v), epsilon) - }) - } -} - -// go test -run Test_GenericParseTypeFloat64s -func Test_GenericParseTypeFloat64s(t *testing.T) { - t.Parallel() - - type genericTypes[v GenericType] struct { - value v - str string - } - - float64s := []genericTypes[float64]{ - { - value: float64(3.1415), - str: "3.1415", - }, - { - value: float64(1.234), - str: "1.234", - }, - { - value: float64(2), - str: "2", - }, - { - value: float64(3), - str: "3", - }, - } - - for _, test := range float64s { - var v float64 - tt := test - t.Run("test_genericParseTypeFloat64s", func(t *testing.T) { - t.Parallel() - require.InEpsilon(t, tt.value, genericParseType(tt.str, v), epsilon) - require.InEpsilon(t, tt.value, genericParseType[float64](tt.str, v), epsilon) - }) - } -} - -// go test -run Test_GenericParseTypeArrayBytes -func Test_GenericParseTypeArrayBytes(t *testing.T) { - t.Parallel() - - type genericTypes[v GenericType] struct { - value v - str string - } - - arrBytes := []genericTypes[[]byte]{ - { - value: []byte("alex"), - str: "alex", - }, - { - value: []byte("32.23"), - str: "32.23", - }, - { - value: []byte(nil), - str: "", - }, - { - value: []byte("john"), - str: "john", - }, - } - - for _, test := range arrBytes { - var v []byte - tt := test - t.Run("test_genericParseTypeArrayBytes", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.value, genericParseType(tt.str, v, []byte(nil))) - require.Equal(t, tt.value, genericParseType[[]byte](tt.str, v, []byte(nil))) - }) - } -} - -// go test -run Test_GenericParseTypeBoolean -func Test_GenericParseTypeBoolean(t *testing.T) { - t.Parallel() - - type genericTypes[v GenericType] struct { - value v - str string - } - - bools := []genericTypes[bool]{ - { - str: "True", - value: true, - }, - { - str: "False", - value: false, - }, - { - str: "true", - value: true, - }, - { - str: "false", - value: false, - }, - } + // Early flushing handler + app.Get("/", func(c Ctx) error { + time.Sleep(2 * time.Second) + return c.End() + }) - for _, test := range bools { - var v bool - tt := test - t.Run("test_genericParseTypeBoolean", func(t *testing.T) { - t.Parallel() - if tt.value { - require.True(t, genericParseType(tt.str, v)) - require.True(t, genericParseType[bool](tt.str, v)) - } else { - require.False(t, genericParseType(tt.str, v)) - require.False(t, genericParseType[bool](tt.str, v)) - } - }) - } + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + require.Nil(t, resp) } -// go test -run Test_Ctx_Drop -v -func Test_Ctx_Drop(t *testing.T) { - t.Parallel() - +// go test -run Test_Ctx_End_with_drop_middleware +func Test_Ctx_End_with_drop_middleware(t *testing.T) { app := New() - // Handler that calls Drop - app.Get("/block-me", func(c Ctx) error { + // Middleware that will drop connections + // that persist after c.Next() + app.Use(func(c Ctx) error { + c.Next() //nolint:errcheck // unnecessary to check error return c.Drop() }) - // Additional handler that just calls return - app.Get("/no-response", func(_ Ctx) error { - return nil + // Early flushing handler + app.Get("/", func(c Ctx) error { + c.SendStatus(StatusOK) //nolint:errcheck // unnecessary to check error + return c.End() }) - // Test the Drop method - resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", nil)) - require.ErrorIs(t, err, ErrTestGotEmptyResponse) - require.Nil(t, resp) - - // Test the no-response handler - resp, err = app.Test(httptest.NewRequest(MethodGet, "/no-response", nil)) + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, StatusOK, resp.StatusCode) - require.Equal(t, "0", resp.Header.Get("Content-Length")) } -// go test -run Test_Ctx_DropWithMiddleware -v -func Test_Ctx_DropWithMiddleware(t *testing.T) { - t.Parallel() - +// go test -run Test_Ctx_End_after_drop +func Test_Ctx_End_after_drop(t *testing.T) { app := New() - // Middleware that calls Drop - app.Use(func(c Ctx) error { - err := c.Next() - c.Set("X-Test", "test") - return err - }) - - // Handler that calls Drop - app.Get("/block-me", func(c Ctx) error { - return c.Drop() - }) - - // Test the Drop method - resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", nil)) - require.ErrorIs(t, err, ErrTestGotEmptyResponse) - require.Nil(t, resp) -} - -// go test -run Test_Ctx_End -func Test_Ctx_End(t *testing.T) { - app := New() - - app.Get("/", func(c Ctx) error { - c.SendString("Hello, World!") //nolint:errcheck // unnecessary to check error - return c.End() - }) - - resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, StatusOK, resp.StatusCode) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err, "io.ReadAll(resp.Body)") - require.Equal(t, "Hello, World!", string(body)) -} - -// go test -run Test_Ctx_End_after_timeout -func Test_Ctx_End_after_timeout(t *testing.T) { - app := New() - - // Early flushing handler - app.Get("/", func(c Ctx) error { - time.Sleep(2 * time.Second) - return c.End() - }) - - resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) - require.ErrorIs(t, err, os.ErrDeadlineExceeded) - require.Nil(t, resp) -} - -// go test -run Test_Ctx_End_with_drop_middleware -func Test_Ctx_End_with_drop_middleware(t *testing.T) { - app := New() - - // Middleware that will drop connections - // that persist after c.Next() - app.Use(func(c Ctx) error { - c.Next() //nolint:errcheck // unnecessary to check error - return c.Drop() - }) - - // Early flushing handler - app.Get("/", func(c Ctx) error { - c.SendStatus(StatusOK) //nolint:errcheck // unnecessary to check error - return c.End() - }) - - resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, StatusOK, resp.StatusCode) -} - -// go test -run Test_Ctx_End_after_drop -func Test_Ctx_End_after_drop(t *testing.T) { - app := New() - - // Middleware that ends the request - // after c.Next() + // Middleware that ends the request + // after c.Next() app.Use(func(c Ctx) error { c.Next() //nolint:errcheck // unnecessary to check error return c.End() @@ -6012,561 +5386,6 @@ func Test_Ctx_End_after_drop(t *testing.T) { require.Nil(t, resp) } -// go test -run Test_GenericParseTypeString -func Test_GenericParseTypeString(t *testing.T) { - t.Parallel() - - tests := []string{"john", "doe", "hello", "fiber"} - - for _, test := range tests { - var v string - tt := test - t.Run("test_genericParseTypeString", func(t *testing.T) { - t.Parallel() - require.Equal(t, tt, genericParseType(tt, v)) - require.Equal(t, tt, genericParseType[string](tt, v)) - }) - } -} - -// go test -v -run=^$ -bench=Benchmark_GenericParseTypeInts -benchmem -count=4 -func Benchmark_GenericParseTypeInts(b *testing.B) { - type genericTypes[v GenericType] struct { - value v - str string - } - - ints := []genericTypes[int]{ - { - value: 0, - str: "0", - }, - { - value: 1, - str: "1", - }, - { - value: 2, - str: "2", - }, - { - value: 3, - str: "3", - }, - { - value: 4, - str: "4", - }, - } - - int8s := []genericTypes[int8]{ - { - value: int8(0), - str: "0", - }, - { - value: int8(1), - str: "1", - }, - { - value: int8(2), - str: "2", - }, - { - value: int8(3), - str: "3", - }, - { - value: int8(4), - str: "4", - }, - } - - int16s := []genericTypes[int16]{ - { - value: int16(0), - str: "0", - }, - { - value: int16(1), - str: "1", - }, - { - value: int16(2), - str: "2", - }, - { - value: int16(3), - str: "3", - }, - { - value: int16(4), - str: "4", - }, - } - - int32s := []genericTypes[int32]{ - { - value: int32(0), - str: "0", - }, - { - value: int32(1), - str: "1", - }, - { - value: int32(2), - str: "2", - }, - { - value: int32(3), - str: "3", - }, - { - value: int32(4), - str: "4", - }, - } - - int64s := []genericTypes[int64]{ - { - value: int64(0), - str: "0", - }, - { - value: int64(1), - str: "1", - }, - { - value: int64(2), - str: "2", - }, - { - value: int64(3), - str: "3", - }, - { - value: int64(4), - str: "4", - }, - } - - for _, test := range ints { - b.Run("bench_genericParseTypeInts", func(b *testing.B) { - var res int - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range int8s { - b.Run("benchmark_genericParseTypeInt8s", func(b *testing.B) { - var res int8 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range int16s { - b.Run("benchmark_genericParseTypeInt16s", func(b *testing.B) { - var res int16 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range int32s { - b.Run("benchmark_genericParseType32Ints", func(b *testing.B) { - var res int32 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range int64s { - b.Run("benchmark_genericParseTypeInt64s", func(b *testing.B) { - var res int64 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } -} - -// go test -v -run=^$ -bench=Benchmark_GenericParseTypeUints -benchmem -count=4 -func Benchmark_GenericParseTypeUints(b *testing.B) { - type genericTypes[v GenericType] struct { - value v - str string - } - - uints := []genericTypes[uint]{ - { - value: uint(0), - str: "0", - }, - { - value: uint(1), - str: "1", - }, - { - value: uint(2), - str: "2", - }, - { - value: uint(3), - str: "3", - }, - { - value: uint(4), - str: "4", - }, - } - - uint8s := []genericTypes[uint8]{ - { - value: uint8(0), - str: "0", - }, - { - value: uint8(1), - str: "1", - }, - { - value: uint8(2), - str: "2", - }, - { - value: uint8(3), - str: "3", - }, - { - value: uint8(4), - str: "4", - }, - } - - uint16s := []genericTypes[uint16]{ - { - value: uint16(0), - str: "0", - }, - { - value: uint16(1), - str: "1", - }, - { - value: uint16(2), - str: "2", - }, - { - value: uint16(3), - str: "3", - }, - { - value: uint16(4), - str: "4", - }, - } - - uint32s := []genericTypes[uint32]{ - { - value: uint32(0), - str: "0", - }, - { - value: uint32(1), - str: "1", - }, - { - value: uint32(2), - str: "2", - }, - { - value: uint32(3), - str: "3", - }, - { - value: uint32(4), - str: "4", - }, - } - - uint64s := []genericTypes[uint64]{ - { - value: uint64(0), - str: "0", - }, - { - value: uint64(1), - str: "1", - }, - { - value: uint64(2), - str: "2", - }, - { - value: uint64(3), - str: "3", - }, - { - value: uint64(4), - str: "4", - }, - } - - for _, test := range uints { - b.Run("benchamark_genericParseTypeUints", func(b *testing.B) { - var res uint - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range uint8s { - b.Run("benchamark_genericParseTypeUint8s", func(b *testing.B) { - var res uint8 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range uint16s { - b.Run("benchamark_genericParseTypeUint16s", func(b *testing.B) { - var res uint16 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range uint32s { - b.Run("benchamark_genericParseTypeUint32s", func(b *testing.B) { - var res uint32 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } - - for _, test := range uint64s { - b.Run("benchamark_genericParseTypeUint64s", func(b *testing.B) { - var res uint64 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.Equal(b, test.value, res) - }) - } -} - -// go test -v -run=^$ -bench=Benchmark_GenericParseTypeFloats -benchmem -count=4 -func Benchmark_GenericParseTypeFloats(b *testing.B) { - type genericTypes[v GenericType] struct { - value v - str string - } - - float32s := []genericTypes[float32]{ - { - value: float32(3.1415), - str: "3.1415", - }, - { - value: float32(1.234), - str: "1.234", - }, - { - value: float32(2), - str: "2", - }, - { - value: float32(3), - str: "3", - }, - } - - float64s := []genericTypes[float64]{ - { - value: float64(3.1415), - str: "3.1415", - }, - { - value: float64(1.234), - str: "1.234", - }, - { - value: float64(2), - str: "2", - }, - { - value: float64(3), - str: "3", - }, - } - - for _, test := range float32s { - b.Run("benchmark_genericParseTypeFloat32s", func(b *testing.B) { - var res float32 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.InEpsilon(b, test.value, res, epsilon) - }) - } - - for _, test := range float64s { - b.Run("benchmark_genericParseTypeFloat32s", func(b *testing.B) { - var res float64 - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - require.InEpsilon(b, test.value, res, epsilon) - }) - } -} - -// go test -v -run=^$ -bench=Benchmark_GenericParseTypeArrayBytes -benchmem -count=4 -func Benchmark_GenericParseTypeArrayBytes(b *testing.B) { - type genericTypes[v GenericType] struct { - value v - str string - } - - arrBytes := []genericTypes[[]byte]{ - { - value: []byte("alex"), - str: "alex", - }, - { - value: []byte("32.23"), - str: "32.23", - }, - { - value: []byte(nil), - str: "", - }, - { - value: []byte("john"), - str: "john", - }, - } - - for _, test := range arrBytes { - b.Run("Benchmark_GenericParseTypeArrayBytes", func(b *testing.B) { - var res []byte - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res, []byte(nil)) - } - require.Equal(b, test.value, res) - }) - } -} - -// go test -v -run=^$ -bench=Benchmark_GenericParseTypeBoolean -benchmem -count=4 -func Benchmark_GenericParseTypeBoolean(b *testing.B) { - type genericTypes[v GenericType] struct { - value v - str string - } - - bools := []genericTypes[bool]{ - { - str: "True", - value: true, - }, - { - str: "False", - value: false, - }, - { - str: "true", - value: true, - }, - { - str: "false", - value: false, - }, - } - - for _, test := range bools { - b.Run("Benchmark_GenericParseTypeBoolean", func(b *testing.B) { - var res bool - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test.str, res) - } - if test.value { - require.True(b, res) - } else { - require.False(b, res) - } - }) - } -} - -// go test -v -run=^$ -bench=Benchmark_GenericParseTypeString -benchmem -count=4 -func Benchmark_GenericParseTypeString(b *testing.B) { - tests := []string{"john", "doe", "hello", "fiber"} - - b.ReportAllocs() - b.ResetTimer() - for _, test := range tests { - b.Run("benchmark_genericParseTypeString", func(b *testing.B) { - var res string - b.ReportAllocs() - b.ResetTimer() - for n := 0; n < b.N; n++ { - res = genericParseType(test, res) - } - - require.Equal(b, test, res) - }) - } -} - // go test -v -run=^$ -bench=Benchmark_Ctx_IsProxyTrusted -benchmem -count=4 func Benchmark_Ctx_IsProxyTrusted(b *testing.B) { // Scenario without trusted proxy check diff --git a/docs/api/ctx.md b/docs/api/ctx.md index 76bf8ca67b0..438ba591e1a 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -43,18 +43,34 @@ app.Post("/", func(c fiber.Ctx) error { ### Context -`Context` returns a context implementation that was set by the user earlier or returns a non-nil, empty context if it was not set earlier. +`Context` implements `context.Context`. However due to [current limitations in how fasthttp](https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945) works, `Deadline()`, `Done()` and `Err()` operate as a nop. ```go title="Signature" -func (c fiber.Ctx) Context() context.Context +func (c fiber.Ctx) Deadline() (deadline time.Time, ok bool) +func (c fiber.Ctx) Done() <-chan struct{} +func (c fiber.Ctx) Err() error +func (c fiber.Ctx) Value(key any) any ``` ```go title="Example" -app.Get("/", func(c fiber.Ctx) error { - ctx := c.Context() - // ctx is context implementation set by user +func doSomething(ctx context.Context) { // ... +} + +app.Get("/", func(c fiber.Ctx) error { + doSomething(c) +}) +``` + +#### Value + +Value can be used to retrieve [**`Locals`**](#locals). + +```go title="Example" +app.Get("/", func(c fiber.Ctx) error { + c.Locals(userKey, "admin") + user := c.Value(userKey) // returns "admin" }) ``` @@ -369,24 +385,6 @@ func MyMiddleware() fiber.Handler { } ``` -### SetContext - -Sets the user-specified implementation for the `context.Context` interface. - -```go title="Signature" -func (c fiber.Ctx) SetContext(ctx context.Context) -``` - -```go title="Example" -app.Get("/", func(c fiber.Ctx) error { - ctx := context.Background() - c.SetContext(ctx) - // Here ctx could be any context implementation - - // ... -}) -``` - ### String Returns a unique string representation of the context. diff --git a/docs/guide/routing.md b/docs/guide/routing.md index 8953dc08ac1..81d63f0fb2b 100644 --- a/docs/guide/routing.md +++ b/docs/guide/routing.md @@ -48,7 +48,7 @@ So please be careful to write routes with variable parameters after the routes t ## Parameters -Route parameters are dynamic elements in the route, which are **named** or **not named segments**. This segments that are used to capture the values specified at their position in the URL. The obtained values can be retrieved using the [Params](https://fiber.wiki/context#params) function, with the name of the route parameter specified in the path as their respective keys or for unnamed parameters the character\(\*, +\) and the counter of this. +Route parameters are dynamic elements in the route, which are **named** or **not named segments**. These segments are used to capture the values specified at their position in the URL. The obtained values can be retrieved using the [Params](https://fiber.wiki/context#params) function, with the name of the route parameter specified in the path as their respective keys or, for unnamed parameters, the character\(\*, +\) and the counter of this. The characters :, +, and \* are characters that introduce a parameter. @@ -56,7 +56,7 @@ Greedy parameters are indicated by wildcard\(\*\) or plus\(+\) signs. The routing also offers the possibility to use optional parameters, for the named parameters these are marked with a final "?", unlike the plus sign which is not optional, you can use the wildcard character for a parameter range which is optional and greedy. -### Example of define routes with route parameters +### Example of defining routes with route parameters ```go // Parameters diff --git a/docs/middleware/proxy.md b/docs/middleware/proxy.md index 8404efe2d86..a47206b74ce 100644 --- a/docs/middleware/proxy.md +++ b/docs/middleware/proxy.md @@ -43,15 +43,14 @@ import ( After you initiate your Fiber app, you can use the following possibilities: ```go -// if target https site uses a self-signed certificate, you should -// call WithTLSConfig before Do and Forward -proxy.WithTLSConfig(&tls.Config{ - InsecureSkipVerify: true, -}) // if you need to use global self-custom client, you should use proxy.WithClient. proxy.WithClient(&fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, + // if target https site uses a self-signed certificate, you should + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, }) // Forward to url @@ -164,7 +163,7 @@ app.Use(proxy.Balancer(proxy.Config{ | Timeout | `time.Duration` | Timeout is the request timeout used when calling the proxy client. | 1 second | | ReadBufferSize | `int` | Per-connection buffer size for requests' reading. This also limits the maximum header size. Increase this buffer if your clients send multi-KB RequestURIs and/or multi-KB headers (for example, BIG cookies). | (Not specified) | | WriteBufferSize | `int` | Per-connection buffer size for responses' writing. | (Not specified) | -| TlsConfig | `*tls.Config` (or `*fasthttp.TLSConfig` in v3) | TLS config for the HTTP client. | `nil` | +| TLSConfig | `*tls.Config` (or `*fasthttp.TLSConfig` in v3) | TLS config for the HTTP client. | `nil` | | DialDualStack | `bool` | Client will attempt to connect to both IPv4 and IPv6 host addresses if set to true. | `false` | | Client | `*fasthttp.LBClient` | Client is a custom client when client config is complex. | `nil` | diff --git a/docs/whats_new.md b/docs/whats_new.md index f1ff0239ebe..5ee03bde128 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -34,11 +34,12 @@ Here's a quick overview of the changes in Fiber `v3`: - [CSRF](#csrf) - [Compression](#compression) - [EncryptCookie](#encryptcookie) - - [Session](#session) - - [Logger](#logger) - [Filesystem](#filesystem) - - [Monitor](#monitor) - [Healthcheck](#healthcheck) + - [Logger](#logger) + - [Monitor](#monitor) + - [Proxy](#proxy) + - [Session](#session) - [🔌 Addons](#-addons) - [📋 Migration guide](#-migration-guide) @@ -74,6 +75,8 @@ We have made several changes to the Fiber app, including: - **ListenTLSWithCertificate**: Use `app.Listen()` with `tls.Config`. - **ListenMutualTLS**: Use `app.Listen()` with `tls.Config`. - **ListenMutualTLSWithCertificate**: Use `app.Listen()` with `tls.Config`. +- **Context()**: Use `Ctx` instead, it follow the `context.Context` interface +- **SetContext()**: Use `Ctx` instead, it follow the `context.Context` interface ### Method Changes @@ -393,10 +396,14 @@ testConfig := fiber.TestConfig{ ### New Features - Cookie now allows Partitioned cookies for [CHIPS](https://developers.google.com/privacy-sandbox/3pcd/chips) support. CHIPS (Cookies Having Independent Partitioned State) is a feature that improves privacy by allowing cookies to be partitioned by top-level site, mitigating cross-site tracking. +- Context now implements [context.Context](https://pkg.go.dev/context#Context). ### New Methods - **AutoFormat**: Similar to Express.js, automatically formats the response based on the request's `Accept` header. +- **Deadline**: For implementing `context.Context`. +- **Done**: For implementing `context.Context`. +- **Err**: For implementing `context.Context`. - **Host**: Similar to Express.js, returns the host name of the request. - **Port**: Similar to Express.js, returns the port number of the request. - **IsProxyTrusted**: Checks the trustworthiness of the remote IP. @@ -406,6 +413,7 @@ testConfig := fiber.TestConfig{ - **SendStreamWriter**: Sends a stream using a writer function. - **SendString**: Similar to Express.js, sends a string as the response. - **String**: Similar to Express.js, converts a value to a string. +- **Value**: For implementing `context.Context`. Returns request-scoped value from Locals. - **ViewBind**: Binds data to a view, replacing the old `Bind` method. - **CBOR**: Introducing [CBOR](https://cbor.io/) binary encoding format for both request & response body. CBOR is a binary data serialization format which is both compact and efficient, making it ideal for use in web applications. - **Drop**: Terminates the client connection silently without sending any HTTP headers or response body. This can be used for scenarios where you want to block certain requests without notifying the client, such as mitigating DDoS attacks or protecting sensitive endpoints from unauthorized access. @@ -988,21 +996,29 @@ We've added support for `zstd` compression on top of `gzip`, `deflate`, and `bro Added support for specifying Key length when using `encryptcookie.GenerateKey(length)`. This allows the user to generate keys compatible with `AES-128`, `AES-192`, and `AES-256` (Default). -### Session +### Filesystem -The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management. +We've decided to remove filesystem middleware to clear up the confusion between static and filesystem middleware. +Now, static middleware can do everything that filesystem middleware and static do. You can check out [static middleware](./middleware/static.md) or [migration guide](#-migration-guide) to see what has been changed. -#### Key Updates +### Healthcheck -- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewStore` or `NewWithStore` for custom store integration. +The Healthcheck middleware has been enhanced to support more than two routes, with default endpoints for liveliness, readiness, and startup checks. Here's a detailed breakdown of the changes and how to use the new features. -- **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`. +1. **Support for More Than Two Routes**: + - The updated middleware now supports multiple routes beyond the default liveliness and readiness endpoints. This allows for more granular health checks, such as startup probes. -- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. +2. **Default Endpoints**: + - Three default endpoints are now available: + - **Liveness**: `/livez` + - **Readiness**: `/readyz` + - **Startup**: `/startupz` + - These endpoints can be customized or replaced with user-defined routes. -- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. +3. **Simplified Configuration**: + - The configuration for each health check endpoint has been simplified. Each endpoint can be configured separately, allowing for more flexibility and readability. -For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). +Refer to the [healthcheck middleware migration guide](./middleware/healthcheck.md) or the [general migration guide](#-migration-guide) to review the changes. ### Logger @@ -1161,33 +1177,29 @@ app.Use(logger.New(logger.Config{ See more in [Logger](./middleware/logger.md#predefined-formats) -### Filesystem - -We've decided to remove filesystem middleware to clear up the confusion between static and filesystem middleware. -Now, static middleware can do everything that filesystem middleware and static do. You can check out [static middleware](./middleware/static.md) or [migration guide](#-migration-guide) to see what has been changed. - ### Monitor Monitor middleware is migrated to the [Contrib package](https://github.com/gofiber/contrib/tree/main/monitor) with [PR #1172](https://github.com/gofiber/contrib/pull/1172). -### Healthcheck +### Proxy -The Healthcheck middleware has been enhanced to support more than two routes, with default endpoints for liveliness, readiness, and startup checks. Here's a detailed breakdown of the changes and how to use the new features. +The proxy middleware has been updated to improve consistency with Go naming conventions. The `TlsConfig` field in the configuration struct has been renamed to `TLSConfig`. Additionally, the `WithTlsConfig` method has been removed; you should now configure TLS directly via the `TLSConfig` property within the `Config` struct. -1. **Support for More Than Two Routes**: - - The updated middleware now supports multiple routes beyond the default liveliness and readiness endpoints. This allows for more granular health checks, such as startup probes. +### Session -2. **Default Endpoints**: - - Three default endpoints are now available: - - **Liveness**: `/livez` - - **Readiness**: `/readyz` - - **Startup**: `/startupz` - - These endpoints can be customized or replaced with user-defined routes. +The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management. -3. **Simplified Configuration**: - - The configuration for each health check endpoint has been simplified. Each endpoint can be configured separately, allowing for more flexibility and readability. +#### Key Updates -Refer to the [healthcheck middleware migration guide](./middleware/healthcheck.md) or the [general migration guide](#-migration-guide) to review the changes. +- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewStore` or `NewWithStore` for custom store integration. + +- **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`. + +- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. + +- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. + +For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). ## 🔌 Addons @@ -1257,6 +1269,7 @@ func main() { - [Filesystem](#filesystem-1) - [Healthcheck](#healthcheck-1) - [Monitor](#monitor-1) + - [Proxy](#proxy-1) ### 🚀 App @@ -1854,3 +1867,29 @@ import "github.com/gofiber/contrib/monitor" app.Use("/metrics", monitor.New()) ``` + +#### Proxy + +In previous versions, TLS settings for the proxy middleware were set using the `WithTlsConfig` method. This method has been removed in favor of a more idiomatic configuration via the `TLSConfig` field in the `Config` struct. + +#### Before (v2 usage) + +```go +proxy.WithTlsConfig(&tls.Config{ + InsecureSkipVerify: true, +}) + +// Forward to url +app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) +``` + +#### After (v3 usage) + +```go +proxy.WithClient(&fasthttp.Client{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, +}) + +// Forward to url +app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) +``` diff --git a/go.mod b/go.mod index 2e43d8b9442..e5db4605568 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/tinylib/msgp v1.3.0 github.com/valyala/bytebufferpool v1.0.0 - github.com/valyala/fasthttp v1.60.0 + github.com/valyala/fasthttp v1.62.0 golang.org/x/crypto v0.38.0 ) @@ -23,7 +23,7 @@ require ( github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/net v0.38.0 // indirect + golang.org/x/net v0.40.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.25.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index bc27498737e..bed6c09b40f 100644 --- a/go.sum +++ b/go.sum @@ -26,16 +26,16 @@ github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= -github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= +github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/helpers.go b/helpers.go index 573aab3de38..1084a8cb412 100644 --- a/helpers.go +++ b/helpers.go @@ -567,9 +567,35 @@ func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool { } func parseAddr(raw string) (string, string) { //nolint:revive // Returns (host, port) - if i := strings.LastIndex(raw, ":"); i != -1 { - return raw[:i], raw[i+1:] + if raw == "" { + return "", "" } + + // Handle IPv6 addresses enclosed in brackets as defined by RFC 3986 + if strings.HasPrefix(raw, "[") { + if end := strings.IndexByte(raw, ']'); end != -1 { + host := raw[:end+1] // keep the closing ] + if len(raw) > end+1 && raw[end+1] == ':' { + return host, raw[end+2:] + } + return host, "" + } + } + + // Everything else with a colon + if i := strings.LastIndexByte(raw, ':'); i != -1 { + host, port := raw[:i], raw[i+1:] + + // If “host” still contains ':', we must have hit an un-bracketed IPv6 + // literal. In that form a port is impossible, so treat the whole thing + // as host. + if strings.Contains(host, ":") { + return raw, "" + } + return host, port + } + + // No colon, nothing to split return raw, "" } @@ -577,28 +603,30 @@ const noCacheValue = "no-cache" // isNoCache checks if the cacheControl header value is a `no-cache`. func isNoCache(cacheControl string) bool { - i := strings.Index(cacheControl, noCacheValue) - if i == -1 { - return false - } - - // Xno-cache - if i > 0 && !(cacheControl[i-1] == ' ' || cacheControl[i-1] == ',') { - return false - } - - // bla bla, no-cache - if i+len(noCacheValue) == len(cacheControl) { - return true - } - - // bla bla, no-cacheX - if cacheControl[i+len(noCacheValue)] != ',' { - return false + n := len(cacheControl) + ncLen := len(noCacheValue) + for i := 0; i < n; i++ { + if cacheControl[i] != 'n' { + continue + } + if i+ncLen > n { + return false + } + if cacheControl[i:i+ncLen] != noCacheValue { + continue + } + if i > 0 { + prev := cacheControl[i-1] + if prev != ' ' && prev != ',' { + continue + } + } + if i+ncLen == n || cacheControl[i+ncLen] == ',' { + return true + } } - // OK - return true + return false } var errTestConnClosed = errors.New("testConn is closed") @@ -727,90 +755,105 @@ func Convert[T any](value string, convertor func(string) (T, error), defaultValu return converted, nil } -// assertValueType asserts the type of the result to the type of the value -func assertValueType[V GenericType, T any](result T) V { - v, ok := any(result).(V) - if !ok { - panic(fmt.Errorf("failed to type-assert to %T", v)) - } - return v -} +var ( + errParsedEmptyString = errors.New("parsed result is empty string") + errParsedEmptyBytes = errors.New("parsed result is empty bytes") + errParsedType = errors.New("unsupported generic type") +) -func genericParseDefault[V GenericType](err error, parser func() V, defaultValue ...V) V { +func genericParseType[V GenericType](str string) (V, error) { var v V - if err != nil { - if len(defaultValue) > 0 { - return defaultValue[0] - } - return v - } - return parser() -} - -func genericParseInt[V GenericType](str string, bitSize int, parser func(int64) V, defaultValue ...V) V { - result, err := strconv.ParseInt(str, 10, bitSize) - return genericParseDefault[V](err, func() V { return parser(result) }, defaultValue...) -} - -func genericParseUint[V GenericType](str string, bitSize int, parser func(uint64) V, defaultValue ...V) V { - result, err := strconv.ParseUint(str, 10, bitSize) - return genericParseDefault[V](err, func() V { return parser(result) }, defaultValue...) -} - -func genericParseFloat[V GenericType](str string, bitSize int, parser func(float64) V, defaultValue ...V) V { - result, err := strconv.ParseFloat(str, bitSize) - return genericParseDefault[V](err, func() V { return parser(result) }, defaultValue...) -} - -func genericParseBool[V GenericType](str string, parser func(bool) V, defaultValue ...V) V { - result, err := strconv.ParseBool(str) - return genericParseDefault[V](err, func() V { return parser(result) }, defaultValue...) -} - -//nolint:gosec // Casting in this function is not a concern -func genericParseType[V GenericType](str string, v V, defaultValue ...V) V { switch any(v).(type) { case int: - return genericParseInt[V](str, 0, func(i int64) V { return assertValueType[V, int](int(i)) }, defaultValue...) + result, err := strconv.ParseInt(str, 10, 0) + if err != nil { + return v, fmt.Errorf("failed to parse int: %w", err) + } + return any(int(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case int8: - return genericParseInt[V](str, 8, func(i int64) V { return assertValueType[V, int8](int8(i)) }, defaultValue...) + result, err := strconv.ParseInt(str, 10, 8) + if err != nil { + return v, fmt.Errorf("failed to parse int8: %w", err) + } + return any(int8(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case int16: - return genericParseInt[V](str, 16, func(i int64) V { return assertValueType[V, int16](int16(i)) }, defaultValue...) + result, err := strconv.ParseInt(str, 10, 16) + if err != nil { + return v, fmt.Errorf("failed to parse int16: %w", err) + } + return any(int16(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case int32: - return genericParseInt[V](str, 32, func(i int64) V { return assertValueType[V, int32](int32(i)) }, defaultValue...) + result, err := strconv.ParseInt(str, 10, 32) + if err != nil { + return v, fmt.Errorf("failed to parse int32: %w", err) + } + return any(int32(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case int64: - return genericParseInt[V](str, 64, func(i int64) V { return assertValueType[V, int64](i) }, defaultValue...) + result, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return v, fmt.Errorf("failed to parse int64: %w", err) + } + return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint: - return genericParseUint[V](str, 0, func(i uint64) V { return assertValueType[V, uint](uint(i)) }, defaultValue...) + result, err := strconv.ParseUint(str, 10, 0) + if err != nil { + return v, fmt.Errorf("failed to parse uint: %w", err) + } + return any(uint(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint8: - return genericParseUint[V](str, 8, func(i uint64) V { return assertValueType[V, uint8](uint8(i)) }, defaultValue...) + result, err := strconv.ParseUint(str, 10, 8) + if err != nil { + return v, fmt.Errorf("failed to parse uint8: %w", err) + } + return any(uint8(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint16: - return genericParseUint[V](str, 16, func(i uint64) V { return assertValueType[V, uint16](uint16(i)) }, defaultValue...) + result, err := strconv.ParseUint(str, 10, 16) + if err != nil { + return v, fmt.Errorf("failed to parse uint16: %w", err) + } + return any(uint16(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint32: - return genericParseUint[V](str, 32, func(i uint64) V { return assertValueType[V, uint32](uint32(i)) }, defaultValue...) + result, err := strconv.ParseUint(str, 10, 32) + if err != nil { + return v, fmt.Errorf("failed to parse uint32: %w", err) + } + return any(uint32(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint64: - return genericParseUint[V](str, 64, func(i uint64) V { return assertValueType[V, uint64](i) }, defaultValue...) + result, err := strconv.ParseUint(str, 10, 64) + if err != nil { + return v, fmt.Errorf("failed to parse uint64: %w", err) + } + return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case float32: - return genericParseFloat[V](str, 32, func(i float64) V { return assertValueType[V, float32](float32(i)) }, defaultValue...) + result, err := strconv.ParseFloat(str, 32) + if err != nil { + return v, fmt.Errorf("failed to parse float32: %w", err) + } + return any(float32(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case float64: - return genericParseFloat[V](str, 64, func(i float64) V { return assertValueType[V, float64](i) }, defaultValue...) + result, err := strconv.ParseFloat(str, 64) + if err != nil { + return v, fmt.Errorf("failed to parse float64: %w", err) + } + return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case bool: - return genericParseBool[V](str, func(b bool) V { return assertValueType[V, bool](b) }, defaultValue...) + result, err := strconv.ParseBool(str) + if err != nil { + return v, fmt.Errorf("failed to parse bool: %w", err) + } + return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case string: - if str == "" && len(defaultValue) > 0 { - return defaultValue[0] + if str == "" { + return v, errParsedEmptyString } - return assertValueType[V, string](str) + return any(str).(V), nil //nolint:errcheck,forcetypeassert // not needed case []byte: - if str == "" && len(defaultValue) > 0 { - return defaultValue[0] + if str == "" { + return v, errParsedEmptyBytes } - return assertValueType[V, []byte]([]byte(str)) + return any([]byte(str)).(V), nil //nolint:errcheck,forcetypeassert // not needed default: - if len(defaultValue) > 0 { - return defaultValue[0] - } - return v + return v, errParsedType } } diff --git a/helpers_test.go b/helpers_test.go index 12d6b60fc1b..0e222a5ffe5 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -5,9 +5,12 @@ package fiber import ( + "math" + "strconv" "strings" "testing" "time" + "unsafe" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" @@ -511,18 +514,28 @@ func Benchmark_Utils_Unescape(b *testing.B) { func Test_Utils_Parse_Address(t *testing.T) { t.Parallel() + testCases := []struct { addr, host, port string }{ {addr: "[::1]:3000", host: "[::1]", port: "3000"}, {addr: "127.0.0.1:3000", host: "127.0.0.1", port: "3000"}, + {addr: "[::1]", host: "[::1]", port: ""}, + {addr: "2001:db8::1", host: "2001:db8::1", port: ""}, {addr: "/path/to/unix/socket", host: "/path/to/unix/socket", port: ""}, + {addr: "127.0.0.1", host: "127.0.0.1", port: ""}, + {addr: "localhost:8080", host: "localhost", port: "8080"}, + {addr: "example.com", host: "example.com", port: ""}, + {addr: "[fe80::1%lo0]:1234", host: "[fe80::1%lo0]", port: "1234"}, + {addr: "[fe80::1%lo0]", host: "[fe80::1%lo0]", port: ""}, + {addr: ":9090", host: "", port: "9090"}, + {addr: "", host: "", port: ""}, } for _, c := range testCases { host, port := parseAddr(c.addr) - require.Equal(t, c.host, host, "addr host") - require.Equal(t, c.port, port, "addr port") + require.Equal(t, c.host, host, "addr host: %q", c.addr) + require.Equal(t, c.port, port, "addr port: %q", c.addr) } } @@ -656,3 +669,643 @@ func Benchmark_SlashRecognition(b *testing.B) { require.True(b, result) }) } + +type testGenericParseTypeIntCase struct { + value int64 + bits int +} + +// go test -run Test_GenericParseTypeInts +func Test_GenericParseTypeInts(t *testing.T) { + t.Parallel() + ints := []testGenericParseTypeIntCase{ + { + value: 0, + bits: 8, + }, + { + value: 1, + bits: 8, + }, + { + value: 2, + bits: 8, + }, + { + value: 3, + bits: 8, + }, + { + value: 4, + bits: 8, + }, + { + value: -1, + bits: 8, + }, + { + value: math.MaxInt8, + bits: 8, + }, + { + value: math.MinInt8, + bits: 8, + }, + { + value: math.MaxInt16, + bits: 16, + }, + { + value: math.MinInt16, + bits: 16, + }, + { + value: math.MaxInt32, + bits: 32, + }, + { + value: math.MinInt32, + bits: 32, + }, + { + value: math.MaxInt64, + bits: 64, + }, + { + value: math.MinInt64, + bits: 64, + }, + } + + testGenericTypeInt[int8](t, "test_genericParseTypeInt8s", ints) + testGenericTypeInt[int16](t, "test_genericParseTypeInt16s", ints) + testGenericTypeInt[int32](t, "test_genericParseTypeInt32s", ints) + testGenericTypeInt[int64](t, "test_genericParseTypeInt64s", ints) + testGenericTypeInt[int](t, "test_genericParseTypeInts", ints) +} + +func testGenericTypeInt[V GenericTypeInteger](t *testing.T, name string, cases []testGenericParseTypeIntCase) { + t.Helper() + t.Run(name, func(t *testing.T) { + t.Parallel() + for _, test := range cases { + v, err := genericParseType[V](strconv.FormatInt(test.value, 10)) + if test.bits <= int(unsafe.Sizeof(V(0)))*8 { + require.NoError(t, err) + require.Equal(t, V(test.value), v) + } else { + require.ErrorIs(t, err, strconv.ErrRange) + } + } + testGenericParseError[V](t) + }) +} + +type testGenericParseTypeUintCase struct { + value uint64 + bits int +} + +// go test -run Test_GenericParseTypeUints +func Test_GenericParseTypeUints(t *testing.T) { + t.Parallel() + uints := []testGenericParseTypeUintCase{ + { + value: 0, + bits: 8, + }, + { + value: 1, + bits: 8, + }, + { + value: 2, + bits: 8, + }, + { + value: 3, + bits: 8, + }, + { + value: 4, + bits: 8, + }, + { + value: math.MaxUint8, + bits: 8, + }, + { + value: math.MaxUint16, + bits: 16, + }, + { + value: math.MaxUint32, + bits: 32, + }, + { + value: math.MaxUint64, + bits: 64, + }, + } + + testGenericTypeUint[uint8](t, "test_genericParseTypeUint8s", uints) + testGenericTypeUint[uint16](t, "test_genericParseTypeUint16s", uints) + testGenericTypeUint[uint32](t, "test_genericParseTypeUint32s", uints) + testGenericTypeUint[uint64](t, "test_genericParseTypeUint64s", uints) + testGenericTypeUint[uint](t, "test_genericParseTypeUints", uints) +} + +func testGenericTypeUint[V GenericTypeInteger](t *testing.T, name string, cases []testGenericParseTypeUintCase) { + t.Helper() + t.Run(name, func(t *testing.T) { + t.Parallel() + for _, test := range cases { + v, err := genericParseType[V](strconv.FormatUint(test.value, 10)) + if test.bits <= int(unsafe.Sizeof(V(0)))*8 { + require.NoError(t, err) + require.Equal(t, V(test.value), v) + } else { + require.ErrorIs(t, err, strconv.ErrRange) + } + } + testGenericParseError[V](t) + }) +} + +// go test -run Test_GenericParseTypeFloats +func Test_GenericParseTypeFloats(t *testing.T) { + t.Parallel() + + floats := []struct { + str string + value float64 + }{ + { + value: 3.1415, + str: "3.1415", + }, + { + value: 1.234, + str: "1.234", + }, + { + value: 2, + str: "2", + }, + { + value: 3, + str: "3", + }, + } + + t.Run("test_genericParseTypeFloat32s", func(t *testing.T) { + t.Parallel() + for _, test := range floats { + v, err := genericParseType[float32](test.str) + require.NoError(t, err) + require.InEpsilon(t, float32(test.value), v, epsilon) + } + testGenericParseError[float32](t) + }) + + t.Run("test_genericParseTypeFloat64s", func(t *testing.T) { + t.Parallel() + for _, test := range floats { + v, err := genericParseType[float64](test.str) + require.NoError(t, err) + require.InEpsilon(t, test.value, v, epsilon) + } + testGenericParseError[float64](t) + }) +} + +// go test -run Test_GenericParseTypeBytes +func Test_GenericParseTypeBytes(t *testing.T) { + t.Parallel() + + cases := []struct { + str string + err error + value []byte + }{ + { + value: []byte("alex"), + str: "alex", + }, + { + value: []byte("32.23"), + str: "32.23", + }, + { + value: []byte("john"), + str: "john", + }, + { + value: []byte(nil), + str: "", + err: errParsedEmptyBytes, + }, + } + + t.Run("test_genericParseTypeBytes", func(t *testing.T) { + t.Parallel() + for _, test := range cases { + v, err := genericParseType[[]byte](test.str) + if test.err == nil { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, test.err) + } + require.Equal(t, test.value, v) + } + }) +} + +// go test -run Test_GenericParseTypeString +func Test_GenericParseTypeString(t *testing.T) { + t.Parallel() + + tests := []string{"john", "doe", "hello", "fiber"} + + for _, test := range tests { + t.Run("test_genericParseTypeString", func(t *testing.T) { + t.Parallel() + v, err := genericParseType[string](test) + require.NoError(t, err) + require.Equal(t, test, v) + }) + } +} + +// go test -run Test_GenericParseTypeBoolean +func Test_GenericParseTypeBoolean(t *testing.T) { + t.Parallel() + + bools := []struct { + str string + value bool + }{ + { + str: "True", + value: true, + }, + { + str: "False", + value: false, + }, + { + str: "true", + value: true, + }, + { + str: "false", + value: false, + }, + } + + t.Run("test_genericParseTypeBoolean", func(t *testing.T) { + t.Parallel() + for _, test := range bools { + v, err := genericParseType[bool](test.str) + require.NoError(t, err) + if test.value { + require.True(t, v) + } else { + require.False(t, v) + } + } + testGenericParseError[bool](t) + }) +} + +func testGenericParseError[V GenericType](t *testing.T) { + t.Helper() + var expected V + v, err := genericParseType[V]("invalid-string") + require.Error(t, err) + require.Equal(t, expected, v) +} + +// go test -v -run=^$ -bench=Benchmark_GenericParseTypeInts -benchmem -count=4 +func Benchmark_GenericParseTypeInts(b *testing.B) { + ints := []testGenericParseTypeIntCase{ + { + value: 0, + bits: 8, + }, + { + value: 1, + bits: 8, + }, + { + value: 2, + bits: 8, + }, + { + value: 3, + bits: 8, + }, + { + value: 4, + bits: 8, + }, + { + value: -1, + bits: 8, + }, + { + value: math.MaxInt8, + bits: 8, + }, + { + value: math.MinInt8, + bits: 8, + }, + { + value: math.MaxInt16, + bits: 16, + }, + { + value: math.MinInt16, + bits: 16, + }, + { + value: math.MaxInt32, + bits: 32, + }, + { + value: math.MinInt32, + bits: 32, + }, + { + value: math.MaxInt64, + bits: 64, + }, + { + value: math.MinInt64, + bits: 64, + }, + } + for _, test := range ints { + benchGenericParseTypeInt[int8](b, "bench_genericParseTypeInt8s", test) + benchGenericParseTypeInt[int16](b, "bench_genericParseTypeInt16s", test) + benchGenericParseTypeInt[int32](b, "bench_genericParseTypeInt32s", test) + benchGenericParseTypeInt[int64](b, "bench_genericParseTypeInt64s", test) + benchGenericParseTypeInt[int](b, "bench_genericParseTypeInts", test) + } +} + +func benchGenericParseTypeInt[V GenericTypeInteger](b *testing.B, name string, test testGenericParseTypeIntCase) { + b.Helper() + b.Run(name, func(t *testing.B) { + var v V + var err error + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + v, err = genericParseType[V](strconv.FormatInt(test.value, 10)) + } + if test.bits <= int(unsafe.Sizeof(V(0)))*8 { + require.NoError(t, err) + require.Equal(t, V(test.value), v) + } else { + require.ErrorIs(t, err, strconv.ErrRange) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_GenericParseTypeUints -benchmem -count=4 +func Benchmark_GenericParseTypeUints(b *testing.B) { + uints := []struct { + value uint64 + bits int + }{ + { + value: 0, + bits: 8, + }, + { + value: 1, + bits: 8, + }, + { + value: 2, + bits: 8, + }, + { + value: 3, + bits: 8, + }, + { + value: 4, + bits: 8, + }, + { + value: math.MaxUint8, + bits: 8, + }, + { + value: math.MaxUint16, + bits: 16, + }, + { + value: math.MaxUint16, + bits: 16, + }, + { + value: math.MaxUint32, + bits: 32, + }, + { + value: math.MaxUint64, + bits: 64, + }, + } + + for _, test := range uints { + benchGenericParseTypeUInt[uint8](b, "benchmark_genericParseTypeUint8s", test) + benchGenericParseTypeUInt[uint16](b, "benchmark_genericParseTypeUint16s", test) + benchGenericParseTypeUInt[uint32](b, "benchmark_genericParseTypeUint32s", test) + benchGenericParseTypeUInt[uint64](b, "benchmark_genericParseTypeUint64s", test) + benchGenericParseTypeUInt[uint](b, "benchmark_genericParseTypeUints", test) + } +} + +func benchGenericParseTypeUInt[V GenericTypeInteger](b *testing.B, name string, test testGenericParseTypeUintCase) { + b.Helper() + b.Run(name, func(t *testing.B) { + var v V + var err error + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + v, err = genericParseType[V](strconv.FormatUint(test.value, 10)) + } + if test.bits <= int(unsafe.Sizeof(V(0)))*8 { + require.NoError(t, err) + require.Equal(t, V(test.value), v) + } else { + require.ErrorIs(t, err, strconv.ErrRange) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_GenericParseTypeFloats -benchmem -count=4 +func Benchmark_GenericParseTypeFloats(b *testing.B) { + floats := []struct { + str string + value float64 + }{ + { + value: 3.1415, + str: "3.1415", + }, + { + value: 1.234, + str: "1.234", + }, + { + value: 2, + str: "2", + }, + { + value: 3, + str: "3", + }, + } + + for _, test := range floats { + b.Run("benchmark_genericParseTypeFloat32s", func(t *testing.B) { + var v float32 + var err error + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + v, err = genericParseType[float32](test.str) + } + require.NoError(t, err) + require.InEpsilon(t, float32(test.value), v, epsilon) + }) + } + + for _, test := range floats { + b.Run("benchmark_genericParseTypeFloat64s", func(t *testing.B) { + var v float64 + var err error + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + v, err = genericParseType[float64](test.str) + } + require.NoError(t, err) + require.InEpsilon(t, test.value, v, epsilon) + }) + } +} + +// go test -v -run=^$ -bench=Benchmark_GenericParseTypeBytes -benchmem -count=4 +func Benchmark_GenericParseTypeBytes(b *testing.B) { + cases := []struct { + str string + err error + value []byte + }{ + { + value: []byte("alex"), + str: "alex", + }, + { + value: []byte("32.23"), + str: "32.23", + }, + { + value: []byte("john"), + str: "john", + }, + { + value: []byte(nil), + str: "", + err: errParsedEmptyBytes, + }, + } + + for _, test := range cases { + b.Run("benchmark_genericParseTypeBytes", func(b *testing.B) { + var v []byte + var err error + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + v, err = genericParseType[[]byte](test.str) + } + if test.err == nil { + require.NoError(b, err) + } else { + require.ErrorIs(b, err, test.err) + } + require.Equal(b, test.value, v) + }) + } +} + +// go test -v -run=^$ -bench=Benchmark_GenericParseTypeString -benchmem -count=4 +func Benchmark_GenericParseTypeString(b *testing.B) { + tests := []string{"john", "doe", "hello", "fiber"} + + for _, test := range tests { + b.Run("benchmark_genericParseTypeString", func(b *testing.B) { + var v string + var err error + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + v, err = genericParseType[string](test) + } + require.NoError(b, err) + require.Equal(b, test, v) + }) + } +} + +// go test -v -run=^$ -bench=Benchmark_GenericParseTypeBoolean -benchmem -count=4 +func Benchmark_GenericParseTypeBoolean(b *testing.B) { + bools := []struct { + str string + value bool + }{ + { + str: "True", + value: true, + }, + { + str: "False", + value: false, + }, + { + str: "true", + value: true, + }, + { + str: "false", + value: false, + }, + } + + for _, test := range bools { + b.Run("benchmark_genericParseTypeBoolean", func(b *testing.B) { + var v bool + var err error + b.ReportAllocs() + b.ResetTimer() + for n := 0; n < b.N; n++ { + v, err = genericParseType[bool](test.str) + } + require.NoError(b, err) + if test.value { + require.True(b, v) + } else { + require.False(b, v) + } + }) + } +} diff --git a/listen.go b/listen.go index d422695cce0..42ef4966c31 100644 --- a/listen.go +++ b/listen.go @@ -309,7 +309,7 @@ func (app *App) printMessages(cfg ListenConfig, ln net.Listener) { } } -// prepareListenData create an slice of ListenData +// prepareListenData creates a slice of ListenData func (*App) prepareListenData(addr string, isTLS bool, cfg ListenConfig) ListenData { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here host, port := parseAddr(addr) if host == "" { diff --git a/middleware/adaptor/adaptor.go b/middleware/adaptor/adaptor.go index 42b72101f9f..7fe1fc20431 100644 --- a/middleware/adaptor/adaptor.go +++ b/middleware/adaptor/adaptor.go @@ -32,8 +32,8 @@ func HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler { // HTTPHandler wraps net/http handler to fiber handler func HTTPHandler(h http.Handler) fiber.Handler { + handler := fasthttpadaptor.NewFastHTTPHandler(h) return func(c fiber.Ctx) error { - handler := fasthttpadaptor.NewFastHTTPHandler(h) handler(c.RequestCtx()) return nil } diff --git a/middleware/adaptor/adaptor_test.go b/middleware/adaptor/adaptor_test.go index 7bede0a8874..2340b7af45f 100644 --- a/middleware/adaptor/adaptor_test.go +++ b/middleware/adaptor/adaptor_test.go @@ -634,3 +634,34 @@ func Benchmark_FiberHandlerFunc_Parallel(b *testing.B) { }) } } + +func Benchmark_HTTPHandler(b *testing.B) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) //nolint:errcheck // not needed + }) + + var err error + app := fiber.New() + + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer func() { + app.ReleaseCtx(ctx) + }() + + b.ReportAllocs() + b.ResetTimer() + + fiberHandler := HTTPHandler(handler) + + for i := 0; i < b.N; i++ { + ctx.Request().Reset() + ctx.Response().Reset() + ctx.Request().SetRequestURI("/test") + ctx.Request().Header.SetMethod("GET") + + err = fiberHandler(ctx) + } + + require.NoError(b, err) +} diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 7f586cd9b3a..f0afa652b50 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -34,14 +34,16 @@ func Test_CSRF(t *testing.T) { h(ctx) // Without CSRF cookie - ctx.Request.Reset() + ctx.Request.Header.Reset() + ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Invalid CSRF token - ctx.Request.Reset() + ctx.Request.Header.Reset() + ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, "johndoe") @@ -49,7 +51,8 @@ func Test_CSRF(t *testing.T) { require.Equal(t, 403, ctx.Response.StatusCode()) // Valid CSRF token - ctx.Request.Reset() + ctx.Request.Header.Reset() + ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(method) h(ctx) @@ -193,12 +196,20 @@ func Test_CSRF_WithSession_Middleware(t *testing.T) { // Generate CSRF token and session_id ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) - csrfTokenParts := strings.Split(string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)), ";") - require.Greater(t, len(csrfTokenParts), 2) - csrfToken := strings.Split(csrfTokenParts[0], "=")[1] + + csrfCookie := fasthttp.AcquireCookie() + csrfCookie.SetKey(ConfigDefault.CookieName) + require.True(t, ctx.Response.Header.Cookie(csrfCookie)) + csrfToken := string(csrfCookie.Value()) require.NotEmpty(t, csrfToken) - sessionID := strings.Split(csrfTokenParts[1], "=")[1] + fasthttp.ReleaseCookie(csrfCookie) + + sessionCookie := fasthttp.AcquireCookie() + sessionCookie.SetKey("session_id") + require.True(t, ctx.Response.Header.Cookie(sessionCookie)) + sessionID := string(sessionCookie.Value()) require.NotEmpty(t, sessionID) + fasthttp.ReleaseCookie(sessionCookie) // Use the CSRF token and session_id ctx.Request.Reset() @@ -1087,7 +1098,8 @@ func Test_CSRF_DeleteToken(t *testing.T) { ctx := &fasthttp.RequestCtx{} // DeleteToken after token generation and remove the cookie - ctx.Request.Reset() + ctx.Request.Header.Reset() + ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.Set(HeaderName, "") handler := HandlerFromContext(app.AcquireCtx(ctx)) @@ -1105,7 +1117,8 @@ func Test_CSRF_DeleteToken(t *testing.T) { token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Delete the CSRF token - ctx.Request.Reset() + ctx.Request.Header.Reset() + ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) @@ -1118,7 +1131,8 @@ func Test_CSRF_DeleteToken(t *testing.T) { } h(ctx) - ctx.Request.Reset() + ctx.Request.Header.Reset() + ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) diff --git a/middleware/proxy/config.go b/middleware/proxy/config.go index 5edee17f908..45dab2dddf1 100644 --- a/middleware/proxy/config.go +++ b/middleware/proxy/config.go @@ -26,10 +26,10 @@ type Config struct { ModifyResponse fiber.Handler // tls config for the http client. - TlsConfig *tls.Config //nolint:stylecheck,revive // TODO: Rename to "TLSConfig" in v3 + TLSConfig *tls.Config // Client is custom client when client config is complex. - // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize, TlsConfig + // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize, TLSConfig // and DialDualStack will not be used if the client are set. Client *fasthttp.LBClient diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 2ac1e2cb444..07a689425a9 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -20,7 +20,7 @@ func Balancer(config Config) fiber.Handler { // Load balanced client lbc := &fasthttp.LBClient{} - // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig + // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TLSConfig // will not be used if the client are set. if config.Client == nil { // Set timeout @@ -44,7 +44,7 @@ func Balancer(config Config) fiber.Handler { ReadBufferSize: config.ReadBufferSize, WriteBufferSize: config.WriteBufferSize, - TLSConfig: config.TlsConfig, + TLSConfig: config.TLSConfig, DialDualStack: config.DialDualStack, } diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 532af09e905..d1f01aa783d 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -55,6 +55,31 @@ func createProxyTestServerIPv6(t *testing.T, handler fiber.Handler) (*fiber.App, return createProxyTestServer(t, handler, fiber.NetworkTCP6, "[::1]:0") } +func createRedirectServer(t *testing.T) string { + t.Helper() + app := fiber.New() + + var addr string + app.Get("/", func(c fiber.Ctx) error { + c.Location("http://" + addr + "/final") + return c.Status(fiber.StatusMovedPermanently).SendString("redirect") + }) + app.Get("/final", func(c fiber.Ctx) error { + return c.SendString("final") + }) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + ln.Close() //nolint:errcheck // It is fine to ignore the error here + }) + addr = ln.Addr().String() + + startServer(app, ln) + + return addr +} + // go test -run Test_Proxy_Empty_Host func Test_Proxy_Empty_Upstream_Servers(t *testing.T) { t.Parallel() @@ -152,7 +177,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { // disable certificate verification in Balancer app.Use(Balancer(Config{ Servers: []string{addr}, - TlsConfig: clientTLSConf, + TLSConfig: clientTLSConf, })) startServer(app, ln) @@ -501,9 +526,14 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) { // go test -race -run Test_Proxy_Do_WithRealURL func Test_Proxy_Do_WithRealURL(t *testing.T) { t.Parallel() + + _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { + return c.SendString("real url") + }) + app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { - return Do(c, "https://www.google.com") + return Do(c, "http://"+addr) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{ @@ -515,15 +545,17 @@ func Test_Proxy_Do_WithRealURL(t *testing.T) { require.Equal(t, "/test", resp.Request.URL.String()) body, err := io.ReadAll(resp.Body) require.NoError(t, err) - require.Contains(t, string(body), "https://www.google.com/") + require.Equal(t, "real url", string(body)) } // go test -race -run Test_Proxy_Do_WithRedirect func Test_Proxy_Do_WithRedirect(t *testing.T) { t.Parallel() + + addr := createRedirectServer(t) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { - return Do(c, "https://google.com") + return Do(c, "http://"+addr) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{ @@ -533,16 +565,18 @@ func Test_Proxy_Do_WithRedirect(t *testing.T) { require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) - require.Contains(t, string(body), "https://www.google.com/") - require.Equal(t, 301, resp.StatusCode) + require.Equal(t, "redirect", string(body)) + require.Equal(t, fiber.StatusMovedPermanently, resp.StatusCode) } // go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { t.Parallel() + + addr := createRedirectServer(t) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { - return DoRedirects(c, "http://google.com", 1) + return DoRedirects(c, "http://"+addr, 1) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{ @@ -550,8 +584,9 @@ func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { FailOnTimeout: true, }) require.NoError(t, err1) - _, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) require.NoError(t, err) + require.Equal(t, "final", string(body)) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "/test", resp.Request.URL.String()) } @@ -559,9 +594,11 @@ func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { // go test -race -run Test_Proxy_DoRedirects_TooManyRedirects func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) { t.Parallel() + + addr := createRedirectServer(t) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { - return DoRedirects(c, "http://google.com", 0) + return DoRedirects(c, "http://"+addr, 0) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), fiber.TestConfig{ diff --git a/middleware/timeout/timeout.go b/middleware/timeout/timeout.go index 127fff87232..5c7e9465270 100644 --- a/middleware/timeout/timeout.go +++ b/middleware/timeout/timeout.go @@ -19,12 +19,9 @@ func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler { // Create a context with the specified timeout; any operation exceeding // this deadline will be canceled automatically. - timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout) + timeoutContext, cancel := context.WithTimeout(ctx, timeout) defer cancel() - // Replace the default Fiber context with our timeout-bound context. - ctx.SetContext(timeoutContext) - // Run the handler and check for relevant errors. err := runHandler(ctx, h, tErrs) diff --git a/middleware/timeout/timeout_test.go b/middleware/timeout/timeout_test.go index 161296a71ad..cb58e9a9ada 100644 --- a/middleware/timeout/timeout_test.go +++ b/middleware/timeout/timeout_test.go @@ -41,7 +41,7 @@ func TestTimeout_Success(t *testing.T) { // Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit. app.Get("/fast", New(func(c fiber.Ctx) error { // Simulate some work - if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil { + if err := sleepWithContext(c, 10*time.Millisecond, context.DeadlineExceeded); err != nil { return err } return c.SendString("OK") @@ -60,7 +60,7 @@ func TestTimeout_Exceeded(t *testing.T) { // This handler sleeps 200ms, exceeding the 100ms limit. app.Get("/slow", New(func(c fiber.Ctx) error { - if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil { + if err := sleepWithContext(c, 200*time.Millisecond, context.DeadlineExceeded); err != nil { return err } return c.SendString("Should never get here") @@ -81,7 +81,7 @@ func TestTimeout_CustomError(t *testing.T) { app.Get("/custom", New(func(c fiber.Ctx) error { // Sleep might time out, or might return early. If the context is canceled, // we treat errCustomTimeout as a 'timeout-like' condition. - if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil { + if err := sleepWithContext(c, 200*time.Millisecond, errCustomTimeout); err != nil { return fmt.Errorf("wrapped: %w", err) } return c.SendString("Should never get here") diff --git a/router.go b/router.go index 3bf5bd8bd05..502f4f650f3 100644 --- a/router.go +++ b/router.go @@ -509,7 +509,7 @@ func (app *App) addRoute(method string, route *Route) { // This method is useful when you want to register routes dynamically after the app has started. // It is not recommended to use this method on production environments because rebuilding // the tree is performance-intensive and not thread-safe in runtime. Since building the tree -// is only done in the startupProcess of the app, this method does not makes sure that the +// is only done in the startupProcess of the app, this method does not make sure that the // routeTree is being safely changed, as it would add a great deal of overhead in the request. // Latest benchmark results showed a degradation from 82.79 ns/op to 94.48 ns/op and can be found in: // https://github.com/gofiber/fiber/issues/2769#issuecomment-2227385283