Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,10 @@ func New(config ...Config) *App {
// Create Ctx pool
app.pool = sync.Pool{
New: func() any {
return app.newCtx()
if app.newCtxFunc != nil {
return app.newCtxFunc(app)
}
return NewDefaultCtx(app)
},
}

Expand Down Expand Up @@ -623,6 +626,15 @@ func New(config ...Config) *App {
return app
}

// NewWithCustomCtx creates a new Fiber instance and applies the
// provided function to generate a custom context type. It mirrors the behavior
// of calling `New()` followed by `app.setCtxFunc(fn)`.
func NewWithCustomCtx(newCtxFunc func(app *App) CustomCtx, config ...Config) *App {
app := New(config...)
app.setCtxFunc(newCtxFunc)
return app
}

// Adds an ip address to TrustProxyConfig.ranges or TrustProxyConfig.ips based on whether it is an IP range or not
func (app *App) handleTrustedProxy(ipAddress string) {
if strings.Contains(ipAddress, "/") {
Expand All @@ -642,13 +654,14 @@ func (app *App) handleTrustedProxy(ipAddress string) {
}
}

// NewCtxFunc allows to customize ctx methods as we want.
// Note: It doesn't allow adding new methods, only customizing exist methods.
func (app *App) NewCtxFunc(function func(app *App) CustomCtx) {
// setCtxFunc applies the given context factory to the app.
// It is used internally by NewWithCustomCtx. It doesn't allow adding new methods,
// only customizing existing ones.
func (app *App) setCtxFunc(function func(app *App) CustomCtx) {
app.newCtxFunc = function

if app.server != nil {
app.server.Handler = app.customRequestHandler
app.server.Handler = app.requestHandler
}
}

Expand Down Expand Up @@ -935,11 +948,7 @@ func (app *App) Config() Config {
func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
// prepare the server for the start
app.startupProcess()

if app.newCtxFunc != nil {
return app.customRequestHandler
}
return app.defaultRequestHandler
return app.requestHandler
}

// Stack returns the raw router stack.
Expand Down Expand Up @@ -1150,11 +1159,7 @@ func (app *App) init() *App {
}

// fasthttp server settings
if app.newCtxFunc != nil {
app.server.Handler = app.customRequestHandler
} else {
app.server.Handler = app.defaultRequestHandler
}
app.server.Handler = app.requestHandler
app.server.Name = app.config.ServerHeader
app.server.Concurrency = app.config.Concurrency
app.server.NoDefaultDate = app.config.DisableDefaultDate
Expand Down
109 changes: 109 additions & 0 deletions binder/binder_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package binder

import (
"mime/multipart"
"reflect"
"strconv"
"testing"

"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)

func Test_GetAndPutToThePool(t *testing.T) {
Expand All @@ -26,3 +30,108 @@ func Test_GetAndPutToThePool(t *testing.T) {
_ = GetFromThePool[*JSONBinding](&JSONBinderPool)
_ = GetFromThePool[*CBORBinding](&CBORBinderPool)
}

func Test_Binders_ErrorPaths(t *testing.T) {
t.Run("query binder invalid key", func(t *testing.T) {
b := &QueryBinding{}
req := fasthttp.AcquireRequest()
req.URI().SetQueryString("invalid[%3Dval&name=john")
defer fasthttp.ReleaseRequest(req)
err := b.Bind(req, &struct{}{})
require.Error(t, err)
require.Contains(t, err.Error(), "unmatched brackets")
})

t.Run("form binder invalid key", func(t *testing.T) {
b := &FormBinding{}
req := fasthttp.AcquireRequest()
req.SetBodyString("invalid[=val")
req.Header.SetContentType("application/x-www-form-urlencoded")
defer fasthttp.ReleaseRequest(req)
err := b.Bind(req, &struct{}{})
require.Error(t, err)
require.Contains(t, err.Error(), "unmatched brackets")
})

t.Run("form binder bad multipart", func(t *testing.T) {
b := &FormBinding{}
req := fasthttp.AcquireRequest()
req.Header.SetContentType(MIMEMultipartForm)
defer fasthttp.ReleaseRequest(req)
err := b.Bind(req, &struct{}{})
require.Error(t, err)
})
}

func Test_GetFieldCache_Panic(t *testing.T) {
t.Parallel()
require.Panics(t, func() { getFieldCache("unknown") })
}

func Test_parseToMap_defaultCase(t *testing.T) {
t.Parallel()
m := map[string]int{}
err := parseToMap(m, map[string][]string{"a": {"1"}})
require.NoError(t, err)
require.Empty(t, m)

m2 := map[string]string{}
err = parseToMap(m2, map[string][]string{"empty": {}})
require.NoError(t, err)
require.Equal(t, "", m2["empty"])
}

func Test_parse_function_maps(t *testing.T) {
t.Parallel()

m := map[string][]string{}
err := parse("query", &m, map[string][]string{"a": {"b"}})
require.NoError(t, err)
require.Equal(t, []string{"b"}, m["a"])

m2 := map[string]string{}
err = parse("query", &m2, map[string][]string{"a": {"b"}})
require.NoError(t, err)
require.Equal(t, "b", m2["a"])
}

func Test_SetParserDecoder_UnknownKeys(t *testing.T) {
SetParserDecoder(ParserConfig{IgnoreUnknownKeys: false})
defer SetParserDecoder(ParserConfig{IgnoreUnknownKeys: true, ZeroEmpty: true})
type user struct {
Name string `query:"name"`
}
data := map[string][]string{"name": {"john"}, "foo": {"bar"}}
err := parseToStruct("query", &user{}, data)
require.Error(t, err)
SetParserDecoder(ParserConfig{IgnoreUnknownKeys: true, ZeroEmpty: true})
}

func Test_SetParserDecoder_CustomConverter(t *testing.T) {
type myInt int
conv := func(s string) reflect.Value {
v, _ := strconv.Atoi(s) //nolint:errcheck // not needed
mi := myInt(v)
return reflect.ValueOf(mi)
}

SetParserDecoder(ParserConfig{ParserType: []ParserType{{CustomType: myInt(0), Converter: conv}}})
defer SetParserDecoder(ParserConfig{IgnoreUnknownKeys: true, ZeroEmpty: true})

type data struct {
V myInt `query:"v"`
}
d := new(data)
err := parse("query", d, map[string][]string{"v": {"5"}})
require.NoError(t, err)
require.Equal(t, myInt(5), d.V)
}

func Test_formatBindData_typeMismatch(t *testing.T) {
t.Parallel()
out := struct{}{}
files := map[string][]*multipart.FileHeader{}
err := formatBindData("query", out, files, "file", 123, false, false)
require.Error(t, err)
require.Equal(t, "unsupported value type: int", err.Error())
}
Loading