diff --git a/flag_def.go b/flag_def.go index 264993f..b576e88 100644 --- a/flag_def.go +++ b/flag_def.go @@ -2,10 +2,12 @@ package command import ( "cmp" + "encoding/csv" "errors" "fmt" "reflect" "strconv" + "strings" ) type ErrInvalidValue struct { @@ -104,6 +106,47 @@ func (fd *flagDef) setValue(sv string) error { } case reflect.String: fv.SetString(sv) + case reflect.Slice: + r := csv.NewReader(strings.NewReader(sv)) + r.LazyQuotes = true + r.TrimLeadingSpace = true + rec, err := r.Read() + if err != nil { + return &ErrInvalidValue{Cause: err, Value: sv, Flag: fd.Name} + } + + inValue := reflect.ValueOf(rec) + + targetType := fv.Type().Elem() + outSlice := reflect.MakeSlice(reflect.SliceOf(targetType), inValue.Len(), inValue.Len()) + for i, inElem := range rec { + var outElem interface{} + var err error + switch targetType.Kind() { + case reflect.String: + outElem = inElem + case reflect.Int: + outElem, err = strconv.Atoi(inElem) + case reflect.Float32: + if f64, parseErr := strconv.ParseFloat(inElem, 32); parseErr == nil { + outElem = float32(f64) + } else { + outElem = nil + err = parseErr + } + case reflect.Float64: + outElem, err = strconv.ParseFloat(inElem, 64) + case reflect.Bool: + outElem, err = strconv.ParseBool(inElem) + default: + return fmt.Errorf("%w: field kind is '%s'", errors.ErrUnsupported, fv.Kind()) + } + if err != nil { + return &ErrInvalidValue{Cause: err, Value: inElem, Flag: fd.Name} + } + outSlice.Index(i).Set(reflect.ValueOf(outElem).Convert(outSlice.Type().Elem())) + } + fv.Set(outSlice) default: return fmt.Errorf("%w: field kind is '%s'", errors.ErrUnsupported, fv.Kind()) } diff --git a/flag_set.go b/flag_set.go index f777baa..0918686 100644 --- a/flag_set.go +++ b/flag_set.go @@ -253,6 +253,17 @@ func (fs *flagSet) readFlagFromField(fieldValue reflect.Value, structField refle case reflect.String: fd.HasValue = true fd.DefaultValue = fieldValue.String() + case reflect.Slice: + fd.HasValue = true + var defaultValues []string + for i := 0; i < fieldValue.Len(); i++ { + defaultValues = append(defaultValues, fieldValue.Index(i).String()) + } + if defaultValues != nil { + fd.DefaultValue = strings.Join(defaultValues, ",") + } else { + fd.DefaultValue = "" + } default: // Unsupported flag field type return fmt.Errorf("unsupported field type: %s", fieldValue.Kind()) diff --git a/flag_set_test.go b/flag_set_test.go index 23a1be7..0d27021 100644 --- a/flag_set_test.go +++ b/flag_set_test.go @@ -609,6 +609,30 @@ func TestNewFlagSet(t *testing.T) { } } +func TestFlagSetWithArrays(t *testing.T) { + t.Parallel() + + config := &struct { + MyArray []string `flag:"true" ` + }{MyArray: []string{"v1", "v2"}} + + valueOfConfig := reflect.ValueOf(config) + fs, err := newFlagSet(nil, valueOfConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + if len(fs.flags) != 1 { + t.Fatalf("Expected 1 flag, got %d", len(fs.flags)) + } + + f := fs.flags[0] + With(t).Verify(f.Name).Will(EqualTo("my-array")).OrFail() + With(t).Verify(f.EnvVarName).Will(BeNil()).OrFail() + With(t).Verify(f.HasValue).Will(EqualTo(true)).OrFail() + With(t).Verify(f.ValueName).Will(BeNil()).OrFail() + With(t).Verify(f.Description).Will(BeNil()).OrFail() + With(t).Verify(f.Required).Will(BeNil()).OrFail() + With(t).Verify(f.DefaultValue).Will(EqualTo("v1,v2")).OrFail() +} + func TestFlagSetGetMergedFlagDefs(t *testing.T) { t.Parallel() type testCase struct { @@ -891,6 +915,51 @@ func TestFlagSetApply(t *testing.T) { expectedError string } testCases := map[string]testCase{ + "all types are supported from CLI": { + config: &struct { + String string `flag:"true"` + Int int `flag:"true"` + Float32 float32 `flag:"true"` + Float64 float64 `flag:"true"` + Bool bool `flag:"true"` + StringArray []string `flag:"true"` + IntArray []int `flag:"true"` + Float32Array []float32 `flag:"true"` + Float64Array []float64 `flag:"true"` + }{}, + args: []string{ + "--string", "s1", + "--int", "9", + "--float32", "1.2", + "--float64", "123.456", + "--bool", + "--string-array", `sa1,"s with space",sa3,,,"`, + "--int-array", `1,2,3,5,8`, + "--float32array", `1.2,3.4,5.6`, + "--float64array", `11.22,33.44,55.66`, + }, + expectedConfig: &struct { + String string `flag:"true"` + Int int `flag:"true"` + Float32 float32 `flag:"true"` + Float64 float64 `flag:"true"` + Bool bool `flag:"true"` + StringArray []string `flag:"true"` + IntArray []int `flag:"true"` + Float32Array []float32 `flag:"true"` + Float64Array []float64 `flag:"true"` + }{ + String: "s1", + Int: 9, + Float32: 1.2, + Float64: 123.456, + Bool: true, + StringArray: []string{"sa1", "s with space", "sa3", "", "", ""}, + IntArray: []int{1, 2, 3, 5, 8}, + Float32Array: []float32{1.2, 3.4, 5.6}, + Float64Array: []float64{11.22, 33.44, 55.66}, + }, + }, "CLI overrides environment variables": { config: &struct { F1 string `name:"my-field1"`