Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 59 additions & 16 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ package schema
import (
"errors"
"reflect"
"strconv"
"strings"
"sync"

utils "github.com/gofiber/utils/v2"
)

const maxParserIndex = 1000
Expand Down Expand Up @@ -52,18 +53,20 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
var struc *structInfo
var field *fieldInfo
var index64 int64
var err error
parts := make([]pathPart, 0)
path := make([]string, 0)
keys := strings.Split(p, ".")
for i := 0; i < len(keys); i++ {
var parts []pathPart
var path []string
for keyStart := 0; ; {
if t.Kind() != reflect.Struct {
return nil, errInvalidPath
}
if struc = c.get(t); struc == nil {
return nil, errInvalidPath
}
if field = struc.get(keys[i]); field == nil {
keyEnd, segment, err := nextPathSegment(p, keyStart)
if err != nil {
return nil, errInvalidPath
}
if field = struc.get(segment); field == nil {
return nil, errInvalidPath
}
// Valid field. Append index.
Expand All @@ -76,11 +79,15 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
// we don't need to force the struct's fields to appear in the path.
// So checking i+2 is not necessary anymore.
// We can skip this part if the type is multipart.FileHeader. It is another special case too.
i++
if i+1 > len(keys) {
keyStart = keyEnd + 1
if keyStart >= len(p) {
return nil, errInvalidPath
}
keyEnd, segment, err = nextPathSegment(p, keyStart)
if err != nil {
return nil, errInvalidPath
}
if index64, err = strconv.ParseInt(keys[i], 10, 0); err != nil {
if index64, err = utils.ParseInt(segment); err != nil {
return nil, errInvalidPath
}
if index64 > maxParserIndex {
Expand All @@ -91,7 +98,7 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
field: field,
index: int(index64),
})
path = make([]string, 0)
path = nil

// Get the next struct type, dropping ptrs.
if field.typ.Kind() == reflect.Ptr {
Expand All @@ -110,6 +117,14 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
} else {
t = field.typ
}

if keyEnd == len(p) {
break
}
keyStart = keyEnd + 1
if keyStart >= len(p) {
return nil, errInvalidPath
}
}
// Add the remaining.
parts = append(parts, pathPart{
Expand All @@ -120,6 +135,17 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
return parts, nil
}

func nextPathSegment(path string, start int) (int, string, error) {
end := start
for end < len(path) && path[end] != '.' {
end++
}
if start == end {
return 0, "", errInvalidPath
}
return end, path[start:end], nil
}

// get returns a cached structInfo, creating it if necessary.
func (c *cache) get(t reflect.Type) *structInfo {
c.l.RLock()
Expand All @@ -139,7 +165,11 @@ func (c *cache) create(t reflect.Type, parentAlias string) *structInfo {
info := &structInfo{}
var anonymousInfos []*structInfo
for i := 0; i < t.NumField(); i++ {
if f := c.createField(t.Field(i), parentAlias); f != nil {
structField := t.Field(i)
if structField.Anonymous && structField.Type.Kind() == reflect.Ptr {
info.anonymousPtrFields = append(info.anonymousPtrFields, i)
}
if f := c.createField(structField, parentAlias); f != nil {
info.fields = append(info.fields, f)
if ft := indirectType(f.typ); ft.Kind() == reflect.Struct && f.isAnonymous {
anonymousInfos = append(anonymousInfos, c.create(ft, f.canonicalAlias))
Expand All @@ -156,6 +186,13 @@ func (c *cache) create(t reflect.Type, parentAlias string) *structInfo {
}
}
}
info.fieldsByName = make(map[string]*fieldInfo, len(info.fields))
for _, field := range info.fields {
aliasKey := utils.ToLower(field.alias)
if _, exists := info.fieldsByName[aliasKey]; !exists {
info.fieldsByName[aliasKey] = field
}
}
Comment thread
ReneWerner87 marked this conversation as resolved.
return info
}

Expand Down Expand Up @@ -218,12 +255,18 @@ func (c *cache) converter(t reflect.Type) Converter {
// ----------------------------------------------------------------------------

type structInfo struct {
fields []*fieldInfo
fields []*fieldInfo
fieldsByName map[string]*fieldInfo
anonymousPtrFields []int
}

func (i *structInfo) get(alias string) *fieldInfo {
aliasKey := utils.ToLower(alias)
if field, ok := i.fieldsByName[aliasKey]; ok {
return field
}
for _, field := range i.fields {
if strings.EqualFold(field.alias, alias) {
if utils.ToLower(field.alias) == aliasKey {
return field
}
}
Expand Down Expand Up @@ -317,8 +360,8 @@ func (o tagOptions) Contains(option string) bool {

func (o tagOptions) getDefaultOptionValue() string {
for _, s := range o {
if strings.HasPrefix(s, "default:") {
return strings.SplitN(s, ":", 2)[1]
if value, ok := strings.CutPrefix(s, "default:"); ok {
return value
}
}
return ""
Expand Down
40 changes: 21 additions & 19 deletions converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package schema
import (
"reflect"
"strconv"

utils "github.com/gofiber/utils/v2"
)

type Converter func(string) reflect.Value
Expand Down Expand Up @@ -57,49 +59,49 @@ func convertBool(value string) reflect.Value {
}

func convertFloat32(value string) reflect.Value {
if v, err := strconv.ParseFloat(value, 32); err == nil {
return reflect.ValueOf(float32(v))
if v, err := utils.ParseFloat32(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertFloat64(value string) reflect.Value {
if v, err := strconv.ParseFloat(value, 64); err == nil {
if v, err := utils.ParseFloat64(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertInt(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 0); err == nil {
if v, err := utils.ParseInt(value); err == nil {
return reflect.ValueOf(int(v))
}
return invalidValue
}

func convertInt8(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 8); err == nil {
return reflect.ValueOf(int8(v))
if v, err := utils.ParseInt8(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertInt16(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 16); err == nil {
return reflect.ValueOf(int16(v))
if v, err := utils.ParseInt16(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertInt32(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 32); err == nil {
return reflect.ValueOf(int32(v))
if v, err := utils.ParseInt32(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertInt64(value string) reflect.Value {
if v, err := strconv.ParseInt(value, 10, 64); err == nil {
if v, err := utils.ParseInt(value); err == nil {
Comment thread
ReneWerner87 marked this conversation as resolved.
return reflect.ValueOf(v)
}
return invalidValue
Expand All @@ -110,35 +112,35 @@ func convertString(value string) reflect.Value {
}

func convertUint(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 0); err == nil {
if v, err := utils.ParseUint(value); err == nil {
return reflect.ValueOf(uint(v))
}
return invalidValue
}

func convertUint8(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 8); err == nil {
return reflect.ValueOf(uint8(v))
if v, err := utils.ParseUint8(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertUint16(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 16); err == nil {
return reflect.ValueOf(uint16(v))
if v, err := utils.ParseUint16(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertUint32(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 32); err == nil {
return reflect.ValueOf(uint32(v))
if v, err := utils.ParseUint32(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
}

func convertUint64(value string) reflect.Value {
if v, err := strconv.ParseUint(value, 10, 64); err == nil {
if v, err := utils.ParseUint(value); err == nil {
return reflect.ValueOf(v)
}
return invalidValue
Expand Down
19 changes: 12 additions & 7 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,9 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
v = v.Elem()
}

// alloc embedded structs
// Allocate embedded anonymous pointers required for promoted fields.
if v.Type().Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous {
field.Set(reflect.New(field.Type().Elem()))
}
}
d.ensureAnonymousPtrs(v)
}

v = v.FieldByName(name)
Expand Down Expand Up @@ -619,6 +614,16 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
return nil
}

func (d *Decoder) ensureAnonymousPtrs(v reflect.Value) {
info := d.cache.get(v.Type())
for _, idx := range info.anonymousPtrFields {
field := v.Field(idx)
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
}
}

func isTextUnmarshaler(v reflect.Value) unmarshaler {
// Create a new unmarshaller instance
m := unmarshaler{}
Expand Down
58 changes: 58 additions & 0 deletions decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,64 @@ func TestInvalidPathInCacheParsePath(t *testing.T) {
}
}

func TestCacheParsePathInvalidPathCases(t *testing.T) {
t.Parallel()

type item struct {
Value string `schema:"value"`
}
type payload struct {
N1 []item `schema:"n1"`
Name string `schema:"name"`
}

c := newCache()
payloadType := reflect.TypeOf(payload{})

tests := []struct {
name string
path string
typ reflect.Type
}{
{
name: "non struct type",
path: "name",
typ: reflect.TypeOf(0),
},
{
name: "unknown field",
path: "missing",
typ: payloadType,
},
{
name: "empty index segment",
path: "n1..value",
typ: payloadType,
},
{
name: "invalid index segment",
path: "n1.x.value",
typ: payloadType,
},
{
name: "trailing dot after index",
path: "n1.0.",
typ: payloadType,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

_, err := c.parsePath(tt.path, tt.typ)
if !errors.Is(err, errInvalidPath) {
t.Fatalf("expected errInvalidPath for %q, got: %v", tt.path, err)
}
})
}
}

// issue 32
func TestDecodeToTypedField(t *testing.T) {
type Aa bool
Expand Down
Loading