diff --git a/cache.go b/cache.go index 770fd8e..788a4f2 100644 --- a/cache.go +++ b/cache.go @@ -7,9 +7,10 @@ package schema import ( "errors" "reflect" - "strconv" "strings" "sync" + + utils "github.com/gofiber/utils/v2" ) const maxParserIndex = 1000 @@ -52,18 +53,20 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { var struc *structInfo var field *fieldInfo var index64 int64 - var err error - parts := make([]pathPart, 0) - path := make([]string, 0) - keys := strings.Split(p, ".") - for i := 0; i < len(keys); i++ { + var parts []pathPart + var path []string + for keyStart := 0; ; { if t.Kind() != reflect.Struct { return nil, errInvalidPath } if struc = c.get(t); struc == nil { return nil, errInvalidPath } - if field = struc.get(keys[i]); field == nil { + keyEnd, segment, err := nextPathSegment(p, keyStart) + if err != nil { + return nil, errInvalidPath + } + if field = struc.get(segment); field == nil { return nil, errInvalidPath } // Valid field. Append index. @@ -76,11 +79,15 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { // we don't need to force the struct's fields to appear in the path. // So checking i+2 is not necessary anymore. // We can skip this part if the type is multipart.FileHeader. It is another special case too. - i++ - if i+1 > len(keys) { + keyStart = keyEnd + 1 + if keyStart >= len(p) { + return nil, errInvalidPath + } + keyEnd, segment, err = nextPathSegment(p, keyStart) + if err != nil { return nil, errInvalidPath } - if index64, err = strconv.ParseInt(keys[i], 10, 0); err != nil { + if index64, err = utils.ParseInt(segment); err != nil { return nil, errInvalidPath } if index64 > maxParserIndex { @@ -91,7 +98,7 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { field: field, index: int(index64), }) - path = make([]string, 0) + path = nil // Get the next struct type, dropping ptrs. if field.typ.Kind() == reflect.Ptr { @@ -110,6 +117,14 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { } else { t = field.typ } + + if keyEnd == len(p) { + break + } + keyStart = keyEnd + 1 + if keyStart >= len(p) { + return nil, errInvalidPath + } } // Add the remaining. parts = append(parts, pathPart{ @@ -120,6 +135,17 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { return parts, nil } +func nextPathSegment(path string, start int) (int, string, error) { + end := start + for end < len(path) && path[end] != '.' { + end++ + } + if start == end { + return 0, "", errInvalidPath + } + return end, path[start:end], nil +} + // get returns a cached structInfo, creating it if necessary. func (c *cache) get(t reflect.Type) *structInfo { c.l.RLock() @@ -139,7 +165,11 @@ func (c *cache) create(t reflect.Type, parentAlias string) *structInfo { info := &structInfo{} var anonymousInfos []*structInfo for i := 0; i < t.NumField(); i++ { - if f := c.createField(t.Field(i), parentAlias); f != nil { + structField := t.Field(i) + if structField.Anonymous && structField.Type.Kind() == reflect.Ptr { + info.anonymousPtrFields = append(info.anonymousPtrFields, i) + } + if f := c.createField(structField, parentAlias); f != nil { info.fields = append(info.fields, f) if ft := indirectType(f.typ); ft.Kind() == reflect.Struct && f.isAnonymous { anonymousInfos = append(anonymousInfos, c.create(ft, f.canonicalAlias)) @@ -156,6 +186,13 @@ func (c *cache) create(t reflect.Type, parentAlias string) *structInfo { } } } + info.fieldsByName = make(map[string]*fieldInfo, len(info.fields)) + for _, field := range info.fields { + aliasKey := utils.ToLower(field.alias) + if _, exists := info.fieldsByName[aliasKey]; !exists { + info.fieldsByName[aliasKey] = field + } + } return info } @@ -218,12 +255,18 @@ func (c *cache) converter(t reflect.Type) Converter { // ---------------------------------------------------------------------------- type structInfo struct { - fields []*fieldInfo + fields []*fieldInfo + fieldsByName map[string]*fieldInfo + anonymousPtrFields []int } func (i *structInfo) get(alias string) *fieldInfo { + aliasKey := utils.ToLower(alias) + if field, ok := i.fieldsByName[aliasKey]; ok { + return field + } for _, field := range i.fields { - if strings.EqualFold(field.alias, alias) { + if utils.ToLower(field.alias) == aliasKey { return field } } @@ -317,8 +360,8 @@ func (o tagOptions) Contains(option string) bool { func (o tagOptions) getDefaultOptionValue() string { for _, s := range o { - if strings.HasPrefix(s, "default:") { - return strings.SplitN(s, ":", 2)[1] + if value, ok := strings.CutPrefix(s, "default:"); ok { + return value } } return "" diff --git a/converter.go b/converter.go index 4bae6df..6ac30cc 100644 --- a/converter.go +++ b/converter.go @@ -7,6 +7,8 @@ package schema import ( "reflect" "strconv" + + utils "github.com/gofiber/utils/v2" ) type Converter func(string) reflect.Value @@ -57,49 +59,49 @@ func convertBool(value string) reflect.Value { } func convertFloat32(value string) reflect.Value { - if v, err := strconv.ParseFloat(value, 32); err == nil { - return reflect.ValueOf(float32(v)) + if v, err := utils.ParseFloat32(value); err == nil { + return reflect.ValueOf(v) } return invalidValue } func convertFloat64(value string) reflect.Value { - if v, err := strconv.ParseFloat(value, 64); err == nil { + if v, err := utils.ParseFloat64(value); err == nil { return reflect.ValueOf(v) } return invalidValue } func convertInt(value string) reflect.Value { - if v, err := strconv.ParseInt(value, 10, 0); err == nil { + if v, err := utils.ParseInt(value); err == nil { return reflect.ValueOf(int(v)) } return invalidValue } func convertInt8(value string) reflect.Value { - if v, err := strconv.ParseInt(value, 10, 8); err == nil { - return reflect.ValueOf(int8(v)) + if v, err := utils.ParseInt8(value); err == nil { + return reflect.ValueOf(v) } return invalidValue } func convertInt16(value string) reflect.Value { - if v, err := strconv.ParseInt(value, 10, 16); err == nil { - return reflect.ValueOf(int16(v)) + if v, err := utils.ParseInt16(value); err == nil { + return reflect.ValueOf(v) } return invalidValue } func convertInt32(value string) reflect.Value { - if v, err := strconv.ParseInt(value, 10, 32); err == nil { - return reflect.ValueOf(int32(v)) + if v, err := utils.ParseInt32(value); err == nil { + return reflect.ValueOf(v) } return invalidValue } func convertInt64(value string) reflect.Value { - if v, err := strconv.ParseInt(value, 10, 64); err == nil { + if v, err := utils.ParseInt(value); err == nil { return reflect.ValueOf(v) } return invalidValue @@ -110,35 +112,35 @@ func convertString(value string) reflect.Value { } func convertUint(value string) reflect.Value { - if v, err := strconv.ParseUint(value, 10, 0); err == nil { + if v, err := utils.ParseUint(value); err == nil { return reflect.ValueOf(uint(v)) } return invalidValue } func convertUint8(value string) reflect.Value { - if v, err := strconv.ParseUint(value, 10, 8); err == nil { - return reflect.ValueOf(uint8(v)) + if v, err := utils.ParseUint8(value); err == nil { + return reflect.ValueOf(v) } return invalidValue } func convertUint16(value string) reflect.Value { - if v, err := strconv.ParseUint(value, 10, 16); err == nil { - return reflect.ValueOf(uint16(v)) + if v, err := utils.ParseUint16(value); err == nil { + return reflect.ValueOf(v) } return invalidValue } func convertUint32(value string) reflect.Value { - if v, err := strconv.ParseUint(value, 10, 32); err == nil { - return reflect.ValueOf(uint32(v)) + if v, err := utils.ParseUint32(value); err == nil { + return reflect.ValueOf(v) } return invalidValue } func convertUint64(value string) reflect.Value { - if v, err := strconv.ParseUint(value, 10, 64); err == nil { + if v, err := utils.ParseUint(value); err == nil { return reflect.ValueOf(v) } return invalidValue diff --git a/decoder.go b/decoder.go index 684456f..177e85a 100644 --- a/decoder.go +++ b/decoder.go @@ -410,14 +410,9 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values v = v.Elem() } - // alloc embedded structs + // Allocate embedded anonymous pointers required for promoted fields. if v.Type().Kind() == reflect.Struct { - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous { - field.Set(reflect.New(field.Type().Elem())) - } - } + d.ensureAnonymousPtrs(v) } v = v.FieldByName(name) @@ -619,6 +614,16 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values return nil } +func (d *Decoder) ensureAnonymousPtrs(v reflect.Value) { + info := d.cache.get(v.Type()) + for _, idx := range info.anonymousPtrFields { + field := v.Field(idx) + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + } +} + func isTextUnmarshaler(v reflect.Value) unmarshaler { // Create a new unmarshaller instance m := unmarshaler{} diff --git a/decoder_test.go b/decoder_test.go index 57d3c62..1e922b9 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -1260,6 +1260,64 @@ func TestInvalidPathInCacheParsePath(t *testing.T) { } } +func TestCacheParsePathInvalidPathCases(t *testing.T) { + t.Parallel() + + type item struct { + Value string `schema:"value"` + } + type payload struct { + N1 []item `schema:"n1"` + Name string `schema:"name"` + } + + c := newCache() + payloadType := reflect.TypeOf(payload{}) + + tests := []struct { + name string + path string + typ reflect.Type + }{ + { + name: "non struct type", + path: "name", + typ: reflect.TypeOf(0), + }, + { + name: "unknown field", + path: "missing", + typ: payloadType, + }, + { + name: "empty index segment", + path: "n1..value", + typ: payloadType, + }, + { + name: "invalid index segment", + path: "n1.x.value", + typ: payloadType, + }, + { + name: "trailing dot after index", + path: "n1.0.", + typ: payloadType, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, err := c.parsePath(tt.path, tt.typ) + if !errors.Is(err, errInvalidPath) { + t.Fatalf("expected errInvalidPath for %q, got: %v", tt.path, err) + } + }) + } +} + // issue 32 func TestDecodeToTypedField(t *testing.T) { type Aa bool diff --git a/encoder.go b/encoder.go index 7ea8ad5..967585b 100644 --- a/encoder.go +++ b/encoder.go @@ -5,6 +5,8 @@ import ( "fmt" "reflect" "strconv" + + utils "github.com/gofiber/utils/v2" ) type encoderFunc func(reflect.Value) string @@ -87,26 +89,28 @@ func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { errors := MultiError{} for i := 0; i < v.NumField(); i++ { + fieldValue := v.Field(i) + fieldType := fieldValue.Type() name, opts := fieldAlias(t.Field(i), e.cache.tag) if name == "-" { continue } // Encode struct pointer types if the field is a valid pointer and a struct. - if isValidStructPointer(v.Field(i)) && !e.hasCustomEncoder(v.Field(i).Type()) { - err := e.encode(v.Field(i).Elem(), dst) + if isValidStructPointer(fieldValue) && !e.hasCustomEncoder(fieldType) { + err := e.encode(fieldValue.Elem(), dst) if err != nil { - errors[v.Field(i).Elem().Type().String()] = err + errors[fieldValue.Elem().Type().String()] = err } continue } - encFunc := typeEncoder(v.Field(i).Type(), e.regenc) + encFunc := typeEncoder(fieldType, e.regenc) // Encode non-slice types and custom implementations immediately. if encFunc != nil { - value := encFunc(v.Field(i)) - if opts.Contains("omitempty") && isZero(v.Field(i)) { + value := encFunc(fieldValue) + if opts.Contains("omitempty") && isZero(fieldValue) { continue } @@ -114,31 +118,31 @@ func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { continue } - if v.Field(i).Type().Kind() == reflect.Struct { - err := e.encode(v.Field(i), dst) + if fieldType.Kind() == reflect.Struct { + err := e.encode(fieldValue, dst) if err != nil { - errors[v.Field(i).Type().String()] = err + errors[fieldType.String()] = err } continue } - if v.Field(i).Type().Kind() == reflect.Slice { - encFunc = typeEncoder(v.Field(i).Type().Elem(), e.regenc) + if fieldType.Kind() == reflect.Slice { + encFunc = typeEncoder(fieldType.Elem(), e.regenc) } if encFunc == nil { - errors[v.Field(i).Type().String()] = fmt.Errorf("schema: encoder not found for %v", v.Field(i)) + errors[fieldType.String()] = fmt.Errorf("schema: encoder not found for %v", fieldValue) continue } // Encode a slice. - if v.Field(i).Len() == 0 && opts.Contains("omitempty") { + if fieldValue.Len() == 0 && opts.Contains("omitempty") { continue } dst[name] = []string{} - for j := 0; j < v.Field(i).Len(); j++ { - dst[name] = append(dst[name], encFunc(v.Field(i).Index(j))) + for j := 0; j < fieldValue.Len(); j++ { + dst[name] = append(dst[name], encFunc(fieldValue.Index(j))) } } @@ -189,11 +193,11 @@ func encodeBool(v reflect.Value) string { } func encodeInt(v reflect.Value) string { - return strconv.FormatInt(int64(v.Int()), 10) + return utils.FormatInt(v.Int()) } func encodeUint(v reflect.Value) string { - return strconv.FormatUint(uint64(v.Uint()), 10) + return utils.FormatUint(v.Uint()) } func encodeFloat(v reflect.Value, bits int) string { diff --git a/go.mod b/go.mod index 9e8ac2b..231e350 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,7 @@ module github.com/gofiber/schema go 1.25 + +require github.com/gofiber/utils/v2 v2.0.0 + +require github.com/google/uuid v1.6.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c6356b4 --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= +github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= +github.com/gofiber/utils/v2 v2.0.0 h1:SCC3rpsEDWupFSHtc0RKxg/BKgV0s1qKfZg9Jv6D0sM= +github.com/gofiber/utils/v2 v2.0.0/go.mod h1:xF9v89FfmbrYqI/bQUGN7gR8ZtXot2jxnZvmAUtiavE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shamaton/msgpack/v3 v3.0.0 h1:xl40uxWkSpwBCSTvS5wyXvJRsC6AcVcYeox9PspKiZg= +github.com/shamaton/msgpack/v3 v3.0.0/go.mod h1:DcQG8jrdrQCIxr3HlMYkiXdMhK+KfN2CitkyzsQV4uc= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=