Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion field_describer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func addFieldDescriptions(d FieldDescriptionSet, v reflect.Value) {
t := v.Type()
for isPtr(t) && v.CanInterface() {
o := v.Interface()
if p, ok := o.(FieldDescriber); ok && !isPromotedMethod(o, "DescribeFields") {
if p, ok := o.(FieldDescriber); ok && !isPromotedMethod(v, "DescribeFields") {
p.DescribeFields(d)
}
t = t.Elem()
Expand Down
67 changes: 4 additions & 63 deletions flags.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package fangs

import (
"fmt"
"reflect"

"github.com/spf13/pflag"

"github.com/anchore/go-logger"
Expand All @@ -18,65 +15,9 @@ type FlagAdder interface {
func AddFlags(log logger.Logger, flags *pflag.FlagSet, structs ...any) {
flagSet := NewPFlagSet(log, flags)
for _, o := range structs {
addFlags(log, flagSet, o)
}
}

func addFlags(log logger.Logger, flags FlagSet, o any) {
v := reflect.ValueOf(o)
if !isPtr(v.Type()) {
panic(fmt.Sprintf("AddFlags must be called with pointers, got: %#v", o))
}

invokeAddFlags(log, flags, o)

v, t := base(v)

if isStruct(t) {
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if !includeField(f) {
continue
}
v := v.Field(i)

if isPtr(v.Type()) {
// check if this is a pointer to a struct, if so, we need to initialize it
kind := v.Type().Elem().Kind()
if v.IsNil() && kind == reflect.Struct {
newV := reflect.New(v.Type().Elem())
if v.CanSet() {
v.Set(newV)
}
}
} else {
v = v.Addr()
}

if !v.CanInterface() {
continue
}

addFlags(log, flags, v.Interface())
}
}
}

func invokeAddFlags(_ logger.Logger, flags FlagSet, o any) {
// defer func() {
// // we may need to handle embedded structs having AddFlags methods called,
// // potentially adding flags with existing names. currently the isPromotedMethod
// // function works, but it is fairly brittle as there is no way through standard
// // go reflection to ascertain this information
// if err := recover(); err != nil {
// if log == nil {
// panic(err)
// }
// log.Debugf("got error while invoking AddFlags: %v", err)
// }
// }()

if o, ok := o.(FlagAdder); ok && !isPromotedMethod(o, "AddFlags") {
o.AddFlags(flags)
_ = InvokeAll(o, func(flagAdder FlagAdder) error {
flagAdder.AddFlags(flagSet)
return nil
}, InvokeAllCreateStructs, InvokeAllRequirePtr)
}
}
3 changes: 1 addition & 2 deletions flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"testing"

"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/anchore/go-logger/adapter/discard"
Expand Down Expand Up @@ -69,7 +68,7 @@ func Test_AddFlags_StructRefs(t *testing.T) {
AddFlags(discard.New(), flags, t1)

require.NotNil(t, t1.T2)
assert.Nil(t, t1.T2.Optional)
require.Nil(t, t1.T2.Optional)
}

type Sub2 struct {
Expand Down
202 changes: 202 additions & 0 deletions invoker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package fangs

import (
"fmt"
"reflect"
)

// InvokeAll recursively calls the invoker function with anything implementing the interface in the object graph.
// the type of the parameter to the invoker function is used to determine the interface, which must have exactly one
// method. InvokeAll will also avoid duplicate calls to methods on embedded structs where the method is inherited.
// InvokeAll optionally creates empty structs at every location in the object graph where a nil value exists that would
// point to a struct type; this may be used to ensure certain calls such as AddFlags and Summarize will always reference
// the same objects in memory.
func InvokeAll[T any](obj any, invoker func(T) error, opts ...func(*invokeAll)) error {
invokerFunc := reflect.ValueOf(invoker)
// get the target interface type
invokerFuncType := invokerFunc.Type()
interfaceType := invokerFuncType.In(0) // must have exactly 1 argument per func signature
iv := invokeAll{
interfaceType: interfaceType,
invokeFunc: invokerFunc,
funcName: funcName(interfaceType),
}
for _, opt := range opts {
opt(&iv)
}
return iv.invokeAll(reflect.ValueOf(obj))
}

// InvokeAllCreateStructs is an option to InvokeAll which causes nil structs pointers to be automatically populated with
// empty values
func InvokeAllCreateStructs(iv *invokeAll) {
iv.createStructs = true
}

// InvokeAllRequirePtr is an option to InvokeAll that indicates interface implementations must have a pointer receiver
func InvokeAllRequirePtr(iv *invokeAll) {
iv.requirePtr = true
}

func funcName(interfaceType reflect.Type) string {
if interfaceType.NumMethod() != 1 {
panic(fmt.Sprintf("provided interfaces must have exactly 1 method, got %v", interfaceType.NumMethod()))
}
m := interfaceType.Method(0)
return m.Name
}

type invokeAll struct {
interfaceType reflect.Type
invokeFunc reflect.Value
funcName string
createStructs bool
requirePtr bool
}

func (iv *invokeAll) invoke(v reflect.Value) error {
out := iv.invokeFunc.Call([]reflect.Value{v})[0] // must have exactly 1 error return value per func signature
if out.IsNil() {
return nil
}
return out.Interface().(error)
}

func (iv *invokeAll) invokeAll(v reflect.Value) error {
t := v.Type()

for isPtr(t) {
if v.IsNil() {
return nil
}

if v.CanInterface() {
if v.Type().Implements(iv.interfaceType) && !isPromotedMethod(v, iv.funcName) {
if err := iv.invoke(v); err != nil {
return err
}
}
}
t = t.Elem()
v = v.Elem()
}

// fail if implements the interface with something not using a pointer receiver
if v.Type().Implements(iv.interfaceType) && !isPromotedMethod(v, iv.funcName) {
if iv.requirePtr {
return fmt.Errorf("type implements interface without pointer reference: %v implements %v", v.Type(), iv.interfaceType)
}
if err := iv.invoke(v); err != nil {
return err
}
}

switch {
case isStruct(t):
return iv.invokeAllStruct(v)
case isSlice(t):
return iv.invokeAllSlice(v)
case isMap(t):
return iv.invokeAllMap(v)
}

return nil
}

// invokeAllStruct call recursively on struct fields
func (iv *invokeAll) invokeAllStruct(v reflect.Value) error {
t := v.Type()

for i := 0; i < v.NumField(); i++ {
f := t.Field(i)
if !includeField(f) {
continue
}

v := v.Field(i)

if isNil(v) {
// optionally create structs when there is only a nil pointer to it
if iv.createStructs && isStruct(v.Type().Elem()) {
fv := reflect.New(v.Type().Elem())
v.Set(fv) // set the newly created struct
v = fv
} else {
continue
}
}

for isPtr(v.Type()) {
v = v.Elem()
}

if !v.CanAddr() {
continue
}

if err := iv.invokeAll(v.Addr()); err != nil {
return err
}
}
return nil
}

// invokeAllSlice call recursively on slice items
func (iv *invokeAll) invokeAllSlice(v reflect.Value) error {
for i := 0; i < v.Len(); i++ {
v := v.Index(i)

if isNil(v) {
continue
}

for isPtr(v.Type()) {
v = v.Elem()
}

if !v.CanAddr() {
continue
}

if err := iv.invokeAll(v.Addr()); err != nil {
return err
}
}
return nil
}

// invokeAllMap call recursively on map values
func (iv *invokeAll) invokeAllMap(v reflect.Value) error {
mapV := v
i := v.MapRange()
for i.Next() {
v := i.Value()

if isNil(v) {
continue
}

for isPtr(v.Type()) {
v = v.Elem()
}

if !v.CanAddr() {
// unable to call .Addr() on struct map entries, so copy to a new instance and set on the map
if isStruct(v.Type()) {
newV := reflect.New(v.Type())
newV.Elem().Set(v)
if err := iv.invokeAll(newV); err != nil {
return err
}
mapV.SetMapIndex(i.Key(), newV.Elem())
}

continue
}

if err := iv.invokeAll(v.Addr()); err != nil {
return err
}
}
return nil
}
37 changes: 37 additions & 0 deletions invoker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package fangs

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_invoker(t *testing.T) {
calls := 0
s := badStructImpl{
incr: func() {
calls++
},
}

invoker := func(l PostLoader) error {
return l.PostLoad()
}

err := InvokeAll(s, invoker)
require.NoError(t, err)
require.Equal(t, 1, calls)

err = InvokeAll(s, invoker, InvokeAllRequirePtr)
require.Error(t, err)
require.Equal(t, 1, calls)
}

type badStructImpl struct {
incr func()
}

func (s badStructImpl) PostLoad() error {
s.incr()
return nil
}
Loading