diff --git a/field_describer.go b/field_describer.go index ca7ae44..378a21d 100644 --- a/field_describer.go +++ b/field_describer.go @@ -71,7 +71,7 @@ func addFieldDescriptions(d FieldDescriptionSet, v reflect.Value) { t := v.Type() for isPtr(t) && v.CanInterface() { o := v.Interface() - if p, ok := o.(FieldDescriber); ok && !isPromotedMethod(o, "DescribeFields") { + if p, ok := o.(FieldDescriber); ok && !isPromotedMethod(v, "DescribeFields") { p.DescribeFields(d) } t = t.Elem() diff --git a/flags.go b/flags.go index ec06a89..1c5aa3e 100644 --- a/flags.go +++ b/flags.go @@ -1,9 +1,6 @@ package fangs import ( - "fmt" - "reflect" - "github.com/spf13/pflag" "github.com/anchore/go-logger" @@ -18,65 +15,9 @@ type FlagAdder interface { func AddFlags(log logger.Logger, flags *pflag.FlagSet, structs ...any) { flagSet := NewPFlagSet(log, flags) for _, o := range structs { - addFlags(log, flagSet, o) - } -} - -func addFlags(log logger.Logger, flags FlagSet, o any) { - v := reflect.ValueOf(o) - if !isPtr(v.Type()) { - panic(fmt.Sprintf("AddFlags must be called with pointers, got: %#v", o)) - } - - invokeAddFlags(log, flags, o) - - v, t := base(v) - - if isStruct(t) { - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - if !includeField(f) { - continue - } - v := v.Field(i) - - if isPtr(v.Type()) { - // check if this is a pointer to a struct, if so, we need to initialize it - kind := v.Type().Elem().Kind() - if v.IsNil() && kind == reflect.Struct { - newV := reflect.New(v.Type().Elem()) - if v.CanSet() { - v.Set(newV) - } - } - } else { - v = v.Addr() - } - - if !v.CanInterface() { - continue - } - - addFlags(log, flags, v.Interface()) - } - } -} - -func invokeAddFlags(_ logger.Logger, flags FlagSet, o any) { - // defer func() { - // // we may need to handle embedded structs having AddFlags methods called, - // // potentially adding flags with existing names. currently the isPromotedMethod - // // function works, but it is fairly brittle as there is no way through standard - // // go reflection to ascertain this information - // if err := recover(); err != nil { - // if log == nil { - // panic(err) - // } - // log.Debugf("got error while invoking AddFlags: %v", err) - // } - // }() - - if o, ok := o.(FlagAdder); ok && !isPromotedMethod(o, "AddFlags") { - o.AddFlags(flags) + _ = InvokeAll(o, func(flagAdder FlagAdder) error { + flagAdder.AddFlags(flagSet) + return nil + }, InvokeAllCreateStructs, InvokeAllRequirePtr) } } diff --git a/flags_test.go b/flags_test.go index 0173111..bf794fa 100644 --- a/flags_test.go +++ b/flags_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/spf13/pflag" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/anchore/go-logger/adapter/discard" @@ -69,7 +68,7 @@ func Test_AddFlags_StructRefs(t *testing.T) { AddFlags(discard.New(), flags, t1) require.NotNil(t, t1.T2) - assert.Nil(t, t1.T2.Optional) + require.Nil(t, t1.T2.Optional) } type Sub2 struct { diff --git a/invoker.go b/invoker.go new file mode 100644 index 0000000..a9442d3 --- /dev/null +++ b/invoker.go @@ -0,0 +1,202 @@ +package fangs + +import ( + "fmt" + "reflect" +) + +// InvokeAll recursively calls the invoker function with anything implementing the interface in the object graph. +// the type of the parameter to the invoker function is used to determine the interface, which must have exactly one +// method. InvokeAll will also avoid duplicate calls to methods on embedded structs where the method is inherited. +// InvokeAll optionally creates empty structs at every location in the object graph where a nil value exists that would +// point to a struct type; this may be used to ensure certain calls such as AddFlags and Summarize will always reference +// the same objects in memory. +func InvokeAll[T any](obj any, invoker func(T) error, opts ...func(*invokeAll)) error { + invokerFunc := reflect.ValueOf(invoker) + // get the target interface type + invokerFuncType := invokerFunc.Type() + interfaceType := invokerFuncType.In(0) // must have exactly 1 argument per func signature + iv := invokeAll{ + interfaceType: interfaceType, + invokeFunc: invokerFunc, + funcName: funcName(interfaceType), + } + for _, opt := range opts { + opt(&iv) + } + return iv.invokeAll(reflect.ValueOf(obj)) +} + +// InvokeAllCreateStructs is an option to InvokeAll which causes nil structs pointers to be automatically populated with +// empty values +func InvokeAllCreateStructs(iv *invokeAll) { + iv.createStructs = true +} + +// InvokeAllRequirePtr is an option to InvokeAll that indicates interface implementations must have a pointer receiver +func InvokeAllRequirePtr(iv *invokeAll) { + iv.requirePtr = true +} + +func funcName(interfaceType reflect.Type) string { + if interfaceType.NumMethod() != 1 { + panic(fmt.Sprintf("provided interfaces must have exactly 1 method, got %v", interfaceType.NumMethod())) + } + m := interfaceType.Method(0) + return m.Name +} + +type invokeAll struct { + interfaceType reflect.Type + invokeFunc reflect.Value + funcName string + createStructs bool + requirePtr bool +} + +func (iv *invokeAll) invoke(v reflect.Value) error { + out := iv.invokeFunc.Call([]reflect.Value{v})[0] // must have exactly 1 error return value per func signature + if out.IsNil() { + return nil + } + return out.Interface().(error) +} + +func (iv *invokeAll) invokeAll(v reflect.Value) error { + t := v.Type() + + for isPtr(t) { + if v.IsNil() { + return nil + } + + if v.CanInterface() { + if v.Type().Implements(iv.interfaceType) && !isPromotedMethod(v, iv.funcName) { + if err := iv.invoke(v); err != nil { + return err + } + } + } + t = t.Elem() + v = v.Elem() + } + + // fail if implements the interface with something not using a pointer receiver + if v.Type().Implements(iv.interfaceType) && !isPromotedMethod(v, iv.funcName) { + if iv.requirePtr { + return fmt.Errorf("type implements interface without pointer reference: %v implements %v", v.Type(), iv.interfaceType) + } + if err := iv.invoke(v); err != nil { + return err + } + } + + switch { + case isStruct(t): + return iv.invokeAllStruct(v) + case isSlice(t): + return iv.invokeAllSlice(v) + case isMap(t): + return iv.invokeAllMap(v) + } + + return nil +} + +// invokeAllStruct call recursively on struct fields +func (iv *invokeAll) invokeAllStruct(v reflect.Value) error { + t := v.Type() + + for i := 0; i < v.NumField(); i++ { + f := t.Field(i) + if !includeField(f) { + continue + } + + v := v.Field(i) + + if isNil(v) { + // optionally create structs when there is only a nil pointer to it + if iv.createStructs && isStruct(v.Type().Elem()) { + fv := reflect.New(v.Type().Elem()) + v.Set(fv) // set the newly created struct + v = fv + } else { + continue + } + } + + for isPtr(v.Type()) { + v = v.Elem() + } + + if !v.CanAddr() { + continue + } + + if err := iv.invokeAll(v.Addr()); err != nil { + return err + } + } + return nil +} + +// invokeAllSlice call recursively on slice items +func (iv *invokeAll) invokeAllSlice(v reflect.Value) error { + for i := 0; i < v.Len(); i++ { + v := v.Index(i) + + if isNil(v) { + continue + } + + for isPtr(v.Type()) { + v = v.Elem() + } + + if !v.CanAddr() { + continue + } + + if err := iv.invokeAll(v.Addr()); err != nil { + return err + } + } + return nil +} + +// invokeAllMap call recursively on map values +func (iv *invokeAll) invokeAllMap(v reflect.Value) error { + mapV := v + i := v.MapRange() + for i.Next() { + v := i.Value() + + if isNil(v) { + continue + } + + for isPtr(v.Type()) { + v = v.Elem() + } + + if !v.CanAddr() { + // unable to call .Addr() on struct map entries, so copy to a new instance and set on the map + if isStruct(v.Type()) { + newV := reflect.New(v.Type()) + newV.Elem().Set(v) + if err := iv.invokeAll(newV); err != nil { + return err + } + mapV.SetMapIndex(i.Key(), newV.Elem()) + } + + continue + } + + if err := iv.invokeAll(v.Addr()); err != nil { + return err + } + } + return nil +} diff --git a/invoker_test.go b/invoker_test.go new file mode 100644 index 0000000..f21d48a --- /dev/null +++ b/invoker_test.go @@ -0,0 +1,37 @@ +package fangs + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_invoker(t *testing.T) { + calls := 0 + s := badStructImpl{ + incr: func() { + calls++ + }, + } + + invoker := func(l PostLoader) error { + return l.PostLoad() + } + + err := InvokeAll(s, invoker) + require.NoError(t, err) + require.Equal(t, 1, calls) + + err = InvokeAll(s, invoker, InvokeAllRequirePtr) + require.Error(t, err) + require.Equal(t, 1, calls) +} + +type badStructImpl struct { + incr func() +} + +func (s badStructImpl) PostLoad() error { + s.incr() + return nil +} diff --git a/load.go b/load.go index 98d7f5b..6396864 100644 --- a/load.go +++ b/load.go @@ -77,7 +77,9 @@ func loadConfig(cfg Config, flags flagRefs, configurations ...any) error { } // Convert all populated config options to their internal application values ex: scope string => scopeOpt source.Scope - err = postLoad(reflect.ValueOf(configuration)) + err = InvokeAll(configuration, func(loader PostLoader) error { + return loader.PostLoad() + }, InvokeAllRequirePtr) if err != nil { return err } @@ -200,129 +202,6 @@ func readConfigFile(cfg Config, v *viper.Viper) error { return &viper.ConfigFileNotFoundError{} } -func postLoad(v reflect.Value) error { - t := v.Type() - - for isPtr(t) { - if v.IsNil() { - return nil - } - - if v.CanInterface() { - obj := v.Interface() - if p, ok := obj.(PostLoader); ok && !isPromotedMethod(obj, "PostLoad") { - if err := p.PostLoad(); err != nil { - return err - } - } - } - t = t.Elem() - v = v.Elem() - } - - switch { - case isStruct(t): - return postLoadStruct(v) - case isSlice(t): - return postLoadSlice(v) - case isMap(t): - return postLoadMap(v) - } - - return nil -} - -// postLoadStruct call recursively on struct fields -func postLoadStruct(v reflect.Value) error { - t := v.Type() - - for i := 0; i < v.NumField(); i++ { - f := t.Field(i) - if !includeField(f) { - continue - } - - v := v.Field(i) - - if isNil(v) { - continue - } - - for isPtr(v.Type()) { - v = v.Elem() - } - - if !v.CanAddr() { - continue - } - - if err := postLoad(v.Addr()); err != nil { - return err - } - } - return nil -} - -// postLoadSlice call recursively on slice items -func postLoadSlice(v reflect.Value) error { - for i := 0; i < v.Len(); i++ { - v := v.Index(i) - - if isNil(v) { - continue - } - - for isPtr(v.Type()) { - v = v.Elem() - } - - if !v.CanAddr() { - continue - } - - if err := postLoad(v.Addr()); err != nil { - return err - } - } - return nil -} - -// postLoadMap call recursively on map values -func postLoadMap(v reflect.Value) error { - mapV := v - i := v.MapRange() - for i.Next() { - v := i.Value() - - if isNil(v) { - continue - } - - for isPtr(v.Type()) { - v = v.Elem() - } - - if !v.CanAddr() { - // unable to call .Addr() on struct map entries, so copy to a new instance and set on the map - if isStruct(v.Type()) { - newV := reflect.New(v.Type()) - newV.Elem().Set(v) - if err := postLoad(newV); err != nil { - return err - } - mapV.SetMapIndex(i.Key(), newV.Elem()) - } - - continue - } - - if err := postLoad(v.Addr()); err != nil { - return err - } - } - return nil -} - type flagRefs map[uintptr]*pflag.Flag func commandFlagRefs(cmd *cobra.Command) flagRefs { @@ -388,8 +267,9 @@ func isNil(v reflect.Value) bool { switch v.Type().Kind() { case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Slice: return v.IsNil() + default: + return false } - return false } func isNotFoundErr(err error) bool { diff --git a/summarize_test.go b/summarize_test.go index 6d1a5a4..efaf6ed 100644 --- a/summarize_test.go +++ b/summarize_test.go @@ -300,6 +300,8 @@ func Test_SummarizeValuesWithPointers(t *testing.T) { cmd.AddCommand(subCmd) cmd.Flags().StringVar(&t1.TopString, "top-string", "", "top-string command description") + + // AddFlags needs to be called to bind to any flags, which affects summarize text AddFlags(cfg.Logger, subCmd.Flags(), t1) got := SummarizeCommand(cfg, subCmd, t1) diff --git a/utils.go b/utils.go index 11731a2..395472d 100644 --- a/utils.go +++ b/utils.go @@ -39,8 +39,7 @@ func fileExists(name string) bool { // as there is no way using standard go or reflection to identify this. the method currently // uses some undefined behavior of the go runtime that may change or may be unreliable when // used by structs created with reflection or if debug information is not present -func isPromotedMethod(o any, method string) bool { - v := reflect.ValueOf(o) +func isPromotedMethod(v reflect.Value, method string) bool { t := v.Type() m, ok := t.MethodByName(method) if !ok { diff --git a/utils_test.go b/utils_test.go index a41a831..fd710fd 100644 --- a/utils_test.go +++ b/utils_test.go @@ -10,7 +10,7 @@ import ( func Test_isPromotedMethod(t *testing.T) { s1 := &Sub2{} - require.True(t, !isPromotedMethod(s1, "AddFlags")) + require.True(t, !isPromotedMethod(reflect.ValueOf(s1), "AddFlags")) type Ty1 struct { Something string @@ -18,14 +18,14 @@ func Test_isPromotedMethod(t *testing.T) { } t1 := &Ty1{} - require.True(t, isPromotedMethod(t1, "AddFlags")) + require.True(t, isPromotedMethod(reflect.ValueOf(t1), "AddFlags")) type Ty2 struct { Ty1 } t2 := &Ty2{} - require.True(t, isPromotedMethod(t2, "AddFlags")) + require.True(t, isPromotedMethod(reflect.ValueOf(t2), "AddFlags")) // reflect-created structs do not include promoted methods tt1 := reflect.TypeOf(t1) @@ -36,5 +36,5 @@ func Test_isPromotedMethod(t *testing.T) { assert.False(t, ok) // not a promoted method because the method doesn't exist on the struct - require.True(t, !isPromotedMethod(t3, "AddFlags")) + require.True(t, !isPromotedMethod(reflect.ValueOf(t3), "AddFlags")) }