From 88e5df1c9f6712b66b695a7690719e7f3de732ed Mon Sep 17 00:00:00 2001 From: Arik Kfir Date: Tue, 4 Jun 2024 00:52:02 +0300 Subject: [PATCH] refactor: major overhaul of API This refactor overhauls the API in preparation for v1.0.0. Following major changes have been applied: - Users now implement the Executor interface, instead of the Command interface. The Command object is now the container for a single command, created via the "New" factory. - Configuration schema is read from the given Executor instance's type. Any field in that struct can be a potential flag or args target. The API now allows customization of such flags via struct field tags (to be documented). Configuration can be nested in multi-level structs for grouping and reuse. - Commands can also be created via "MustNew" which panics if command creation fails. - Better command hierarchy with clear definition of which flags are inherited from parent commands to child commands. - Improved help & usage screens - Better code & file structure - Add golang CI linting --- .github/workflows/validation.yaml | 9 + command.go | 619 +++++++----------- command_test.go | 645 ++++++++++-------- configs_for_test.go | 22 - execute.go | 91 +++ execute_test.go | 115 ++++ flag_def.go | 166 +++++ flag_def_test.go | 311 +++++++++ flag_merged.go | 102 +++ flag_merged_test.go | 263 ++++++++ flag_set.go | 509 +++++++++++++++ flag_set_test.go | 1007 +++++++++++++++++++++++++++++ go.mod | 5 +- go.sum | 12 + util.go | 72 +-- util_test.go | 83 --- wrapping_writer.go | 89 +++ wrapping_writer_test.go | 281 ++++++++ 18 files changed, 3590 insertions(+), 811 deletions(-) delete mode 100644 configs_for_test.go create mode 100644 execute.go create mode 100644 execute_test.go create mode 100644 flag_def.go create mode 100644 flag_def_test.go create mode 100644 flag_merged.go create mode 100644 flag_merged_test.go create mode 100644 flag_set.go create mode 100644 flag_set_test.go create mode 100644 wrapping_writer.go create mode 100644 wrapping_writer_test.go diff --git a/.github/workflows/validation.yaml b/.github/workflows/validation.yaml index f1ce45b..c1f9536 100644 --- a/.github/workflows/validation.yaml +++ b/.github/workflows/validation.yaml @@ -14,6 +14,10 @@ jobs: fail-fast: false matrix: go-version: [ '1.21', '1.22' ] + permissions: + contents: read + pull-requests: read + checks: write steps: - name: Checkout uses: actions/checkout@v4 @@ -26,5 +30,10 @@ jobs: - name: Download dependencies run: go mod download -x + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: latest + - name: Test run: go test ./... diff --git a/command.go b/command.go index 6a6674a..d15576a 100644 --- a/command.go +++ b/command.go @@ -1,474 +1,291 @@ package command import ( - "bytes" - "context" - "flag" + "errors" "fmt" "io" "reflect" - "regexp" - "slices" - "strconv" "strings" ) -//goland:noinspection GoUnusedGlobalVariable -var Version = "0.0.0-unknown" - var ( - tokenRE = regexp.MustCompile(`^([^=]+)=(.*)$`) + ErrInvalidCommand = errors.New("invalid command") + ErrCommandAlreadyHasParent = errors.New("command already has a parent") ) -func New(parent *Command, spec Spec) *Command { - if parent != nil && !parent.createdByNewCommand { - panic("illegal parent was specified - was the parent created by 'command.New(...)'?") - } - cmd := &Command{ - Name: spec.Name, - ShortDescription: spec.ShortDescription, - LongDescription: spec.LongDescription, - Config: spec.Config, - PreSubCommandRun: spec.OnSubCommandRun, - Run: spec.Run, - builtinConfig: &BuiltinConfig{Help: false}, - parent: parent, - createdByNewCommand: true, - } - if cmd.parent != nil { - cmd.parent.subCommands = append(cmd.parent.subCommands, cmd) - } - return cmd -} - -type Spec struct { - Name string - ShortDescription string - LongDescription string - Config any - OnSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error - Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error +// HelpConfig is a configuration added to every executed command, for automatic help screen generation. +type HelpConfig struct { + Help bool `inherited:"true" desc:"Show this help screen and exit."` } +// Command is a command instance, created by [New] and can be composed with more Command instances to form a CLI command +// hierarchy. type Command struct { - Name string - ShortDescription string - LongDescription string - Config any - PreSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error - Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error - builtinConfig any - parent *Command - subCommands []*Command - createdByNewCommand bool - envVarsMapping map[string]reflect.Value - flagSet *flag.FlagSet - flagArgNames map[string]string - requiredFlags []string - positionalArgsTarget *[]string -} - -type BuiltinConfig struct { - Help bool `desc:"Show help about how to use this command"` -} - -type UsagePrinter interface { - PrintShortUsage(w io.Writer) - PrintFullUsage(w io.Writer) -} - -func (c *Command) PrintShortUsage(w io.Writer) { - c.printCommandUsage(w, true) + name string + shortDescription string + longDescription string + executor Executor + flags *flagSet + parent *Command + subCommands []*Command + HelpConfig *HelpConfig } -func (c *Command) PrintFullUsage(w io.Writer) { - c.printCommandUsage(w, false) +// MustNew creates a new command using [New], but will panic if it returns an error. +// +//goland:noinspection GoUnusedExportedFunction +func MustNew(name, shortDescription, longDescription string, executor Executor, subCommands ...*Command) *Command { + cmd, err := New(name, shortDescription, longDescription, executor, subCommands...) + if err != nil { + panic(err) + } + return cmd } -func (c *Command) initializeFlagSet() error { - if c.flagArgNames == nil { - c.flagArgNames = make(map[string]string) +// New creates a new command with the given name, short & long descriptions, and the given executor. The executor object +// is also scanned for configuration structs via reflection. +func New(name, shortDescription, longDescription string, executor Executor, subCommands ...*Command) (*Command, error) { + if name == "" { + return nil, fmt.Errorf("%w: empty name", ErrInvalidCommand) + } else if shortDescription == "" { + return nil, fmt.Errorf("%w: empty short description", ErrInvalidCommand) + } else if executor == nil { + return nil, fmt.Errorf("%w: nil executor", ErrInvalidCommand) } - // Create a flag set - name := c.Name - for parent := c.parent; parent != nil; parent = parent.parent { - name = parent.Name + " " + name - } - c.flagSet = flag.NewFlagSet(name, flag.ContinueOnError) - c.flagSet.SetOutput(io.Discard) - if err := c.initializeFlagSetFromStruct(reflect.ValueOf(c.builtinConfig).Elem()); err != nil { - return fmt.Errorf("failed to process builtin configuration fields: %w", err) + // Create the command instance + cmd := &Command{ + name: name, + shortDescription: shortDescription, + longDescription: longDescription, + executor: executor, + HelpConfig: &HelpConfig{}, } - // If this command has no configuration stop here - if c.Config == nil { - return nil + // Set nil parent + if err := cmd.setParent(nil); err != nil { + return nil, fmt.Errorf("failed creating command '%s': %w", name, err) } - // Verify the configuration field's type: must be a pointer to a struct, nothing else is accepted - valueOfCfg := reflect.ValueOf(c.Config) - if valueOfCfg.Kind() == reflect.Struct { - return fmt.Errorf("field 'Config' in command '%s' must not be a struct, but a pointer to a struct: %+v", c.Name, c.Config) - } else if valueOfCfg.Kind() != reflect.Ptr { - return fmt.Errorf("field 'Config' in command '%s' is not a pointer: %+v", c.Name, c.Config) - } else if valueOfCfg.IsNil() { - return nil - } - valueOfCfgStruct := valueOfCfg.Elem() - if valueOfCfgStruct.Kind() != reflect.Struct { - return fmt.Errorf("field 'Config' in command '%s' is not a pointer to struct: %+v", c.Name, c.Config) + // Add sub-commands + for _, subCmd := range subCommands { + if err := cmd.AddSubCommand(subCmd); err != nil { + return nil, fmt.Errorf("%w: failed adding sub-command '%s' to '%s': %w", ErrInvalidCommand, subCmd.name, name, err) + } } - // Process the struct fields as flags - if err := c.initializeFlagSetFromStruct(valueOfCfgStruct); err != nil { - return fmt.Errorf("failed to process configuration fields: %w", err) + return cmd, nil +} + +// setParent updates the parent command of this command. +func (c *Command) setParent(parent *Command) error { + + // Determine the parent flagSet, if any + var parentFlags *flagSet + if parent != nil { + parentFlags = parent.flags + } else if fs, err := newFlagSet(nil, reflect.ValueOf(c).Elem().FieldByName("HelpConfig")); err != nil { + return fmt.Errorf("failed creating Help flag set: %w", err) + } else { + parentFlags = fs } + // Create the flag-set + if fs, err := newFlagSet(parentFlags, reflect.ValueOf(c.executor)); err != nil { + return fmt.Errorf("failed creating flag-set for command '%s': %w", c.name, err) + } else { + c.parent = parent + c.flags = fs + } return nil } -func (c *Command) initializeFlagSetFromStruct(valueOfCfgStruct reflect.Value) error { - for i := 0; i < valueOfCfgStruct.NumField(); i++ { - structField := valueOfCfgStruct.Type().Field(i) - fieldName := structField.Name - fieldValue := valueOfCfgStruct.Field(i) - if !fieldValue.CanAddr() { - return fmt.Errorf("field '%s' is not addressable", fieldName) - } else if !fieldValue.CanSet() { - return fmt.Errorf("field '%s' is not settable", fieldName) - } - - flagName := fieldNameToFlagName(fieldName) - description := structField.Tag.Get("desc") - - envVarName := fieldNameToEnvVarName(fieldName) - if c.envVarsMapping == nil { - c.envVarsMapping = make(map[string]reflect.Value) - } - - // TODO: support commas inside token values (currently split will incorrectly split them) - targetPtr := fieldValue.Addr().Interface() - positionalsField := false - for _, token := range strings.Split(structField.Tag.Get("flag"), ",") { - if token == "ignore" { - if slices.Contains(c.requiredFlags, flagName) { - return fmt.Errorf("field '%s' cannot be both required and ignored", fieldName) - } else { - continue - } - } else if token == "required" { - c.requiredFlags = append(c.requiredFlags, flagName) - } else if token == "args" { - if structField.Type.ConvertibleTo(reflect.TypeOf([]string{})) == false { - return fmt.Errorf("field '%s' has 'args' tag but is not of type '[]string'", fieldName) - } else if c.positionalArgsTarget != nil { - return fmt.Errorf("multiple fields with 'args' tag found in command '%s'", c.Name) - } else { - c.positionalArgsTarget = targetPtr.(*[]string) - positionalsField = true - } - continue - } else if keyValue := tokenRE.FindStringSubmatch(token); keyValue != nil { - key := keyValue[1] - value := keyValue[2] - switch key { - case "valueName": - c.flagArgNames[flagName] = value - default: - return fmt.Errorf("unsupported config tag key: %s", key) - } - } - } - if positionalsField == true { - continue - } - - switch fieldValue.Kind() { - case reflect.Bool: - c.flagSet.BoolVar(targetPtr.(*bool), flagName, fieldValue.Bool(), description) - c.flagArgNames[flagName] = "" // to disable value name in usage page - c.envVarsMapping[envVarName] = fieldValue - case reflect.Int: - c.flagSet.IntVar(targetPtr.(*int), flagName, int(fieldValue.Int()), description) - c.envVarsMapping[envVarName] = fieldValue - case reflect.Uint: - c.flagSet.UintVar(targetPtr.(*uint), flagName, uint(fieldValue.Uint()), description) - c.envVarsMapping[envVarName] = fieldValue - case reflect.Float64: - c.flagSet.Float64Var(targetPtr.(*float64), flagName, fieldValue.Float(), description) - c.envVarsMapping[envVarName] = fieldValue - case reflect.String: - c.flagSet.StringVar(targetPtr.(*string), flagName, fieldValue.String(), description) - c.envVarsMapping[envVarName] = fieldValue - case reflect.Struct: - if err := c.initializeFlagSetFromStruct(fieldValue); err != nil { - return fmt.Errorf("failed adding flags for field '%s': %w", fieldName, err) - } - default: - panic(fmt.Sprintf("unsupported configuration field type: %s\n", fieldValue.Kind())) - } +// AddSubCommand will add the given command as a sub-command of this command. An error is returned if the given command +// already has another parent. +func (c *Command) AddSubCommand(cmd *Command) error { + if cmd.parent != nil { + return fmt.Errorf("%w: %s", ErrCommandAlreadyHasParent, cmd.parent.name) + } + c.subCommands = append(c.subCommands, cmd) + if err := cmd.setParent(c); err != nil { + return fmt.Errorf("failed setting parent for command '%s': %w", cmd.name, err) } - return nil } -func (c *Command) applyEnvironmentVariables(envVars map[string]string) error { - if c.envVarsMapping != nil { - for envVarName, fieldValue := range c.envVarsMapping { - targetPtr := fieldValue.Addr().Interface() - switch fieldValue.Kind() { - case reflect.Bool: - if stringValue, found := envVars[envVarName]; found { - if boolValue, err := strconv.ParseBool(stringValue); err != nil { - return fmt.Errorf("failed to parse environment variable '%s': %w", envVarName, err) - } else { - *targetPtr.(*bool) = boolValue - } - } - case reflect.Int: - if stringValue, found := envVars[envVarName]; found { - if intValue, err := strconv.ParseInt(stringValue, 10, 0); err != nil { - return fmt.Errorf("failed to parse environment variable '%s': %w", envVarName, err) - } else { - *targetPtr.(*int) = int(intValue) - } - } - case reflect.Uint: - if stringValue, found := envVars[envVarName]; found { - if uintValue, err := strconv.ParseUint(stringValue, 10, 0); err != nil { - return fmt.Errorf("failed to parse environment variable '%s': %w", envVarName, err) - } else { - *targetPtr.(*uint) = uint(uintValue) - } - } - case reflect.Float64: - if stringValue, found := envVars[envVarName]; found { - if float64Value, err := strconv.ParseFloat(stringValue, 0); err != nil { - return fmt.Errorf("failed to parse environment variable '%s': %w", envVarName, err) - } else { - *targetPtr.(*float64) = float64Value - } - } - case reflect.String: - if value, found := envVars[envVarName]; found { - *targetPtr.(*string) = value +// inferCommandAndArgs takes the given CLI arguments, and splits them into flags, positional arguments, but most +// importantly, understands which command the user is trying to invoke. This is done by comparing given positional +// arguments to the current command hierarchy, and removing positional arguments that denote sub-commands. +// +// For example, assuming the following command line is given: +// +// cmd1 -flag1 sub1 something -flag2=1 sub2 -- sub3 -flag3 a b c +// +// And the command hierarchy is: cmd1 -> sub1 -> sub2 -> sub3 +// +// The returned values would be: +// - flags: [-flag1, -flag2=1]: no "-flag3" because it's after the "--" separator +// - positionals: [something, sub3, a, b, c]: no "cmd1", "sub1" and "sub2" as they are commands in the hierarchy +// - command: sub2 (since it's the last valid command before the "--" which signals positional args only) +func (c *Command) inferCommandAndArgs(args []string) (flags, positionals []string, current *Command) { + current = c + onlyPositionalArgs := false + for _, arg := range args { + if onlyPositionalArgs { + positionals = append(positionals, arg) + } else if arg == "--" { + onlyPositionalArgs = true + } else if strings.HasPrefix(arg, "-") { + flags = append(flags, arg) + } else { + found := false + for _, subCmd := range current.subCommands { + if subCmd.name == arg { + current = subCmd + found = true + break } - default: - panic(fmt.Sprintf("unsupported configuration field type: %s\n", fieldValue.Kind())) + } + if !found { + positionals = append(positionals, arg) } } } - return nil -} - -func (c *Command) applyCLIArguments(args []string) error { - - // Update config with CLI arguments - if err := c.flagSet.Parse(args); err != nil { - return fmt.Errorf("failed to apply CLI arguments: %w", err) - } - - return nil + return } -func (c *Command) validateRequiredFlagsWereProvided(envVars map[string]string) error { - var missingRequiredFlags []string - copy(missingRequiredFlags, c.requiredFlags) - c.flagSet.Visit(func(f *flag.Flag) { - missingRequiredFlags = slices.DeleteFunc(missingRequiredFlags, func(requiredFlagName string) bool { - return requiredFlagName == f.Name - }) - }) - for envVarName := range c.envVarsMapping { - if _, found := envVars[envVarName]; found { - envVarFlagName := environmentVariableToFlagName(envVarName) - missingRequiredFlags = slices.DeleteFunc(missingRequiredFlags, func(requiredFlagName string) bool { - return requiredFlagName == envVarFlagName - }) +// getFullName returns the names of all commands in this command's hierarchy, starting from the root, all the way to +// this command. +// +// For example, assuming the following command hierarchy: +// +// cmd1 -> sub1 -> sub2 -> sub3 +// +// This function would return "cmd1 sub1" for the "sub1" command. +func (c *Command) getFullName() string { + var fullName string + for cmd := c; cmd != nil; cmd = cmd.parent { + if fullName != "" { + fullName = " " + fullName } + fullName = cmd.name + fullName } - if len(missingRequiredFlags) > 0 { - return fmt.Errorf("these required flags have not set via either CLI nor environment variables: %v", missingRequiredFlags) - } - return nil + return fullName } -func (c *Command) configure(envVars map[string]string, args []string) error { - - // Initialize the flagSet for the chosen command - if err := c.initializeFlagSet(); err != nil { - panic(fmt.Sprintf("failed to initialize flag set for command '%s': %v", c.Name, err)) +// getChain returns the chain of commands for this command, starting from the root, all the way to this command. +func (c *Command) getChain() []*Command { + var chain []*Command + for cmd := c; cmd != nil; cmd = cmd.parent { + chain = append([]*Command{cmd}, chain...) } + return chain +} - // Apply environment variables first - if err := c.applyEnvironmentVariables(envVars); err != nil { - return fmt.Errorf("failed to apply environment variables: %w", err) +func (c *Command) PrintHelp(w io.Writer, width int) error { + ww, err := NewWrappingWriter(width) + if err != nil { + return err } - // Override with CLI arguments - if err := c.flagSet.Parse(args); err != nil { - return fmt.Errorf("failed to apply CLI arguments: %w", err) + prefix4 := strings.Repeat(" ", 4) + prefix8 := strings.Repeat(" ", 8) + fullName := c.getFullName() + + // Command name & short description + if c.shortDescription != "" { + _, _ = fmt.Fprint(ww, fullName) + _, _ = fmt.Fprint(ww, ": ") + _ = ww.SetLinePrefix(prefix4) + _, _ = fmt.Fprintln(ww, c.shortDescription) + _ = ww.SetLinePrefix("") + } else { + _, _ = fmt.Fprintln(ww, fullName) } - - // Apply positional arguments - if c.positionalArgsTarget != nil { - *c.positionalArgsTarget = c.flagSet.Args() + _, _ = fmt.Fprintln(ww) + + // Long description if we have one + if c.longDescription != "" { + _, _ = fmt.Fprint(ww, "Description: ") + _ = ww.SetLinePrefix(prefix4) + _, _ = fmt.Fprintln(ww, c.longDescription) + _ = ww.SetLinePrefix("") + _, _ = fmt.Fprintln(ww) } - return nil -} - -func (c *Command) printCommandUsage(w io.Writer, short bool) { - cmdChain := c.Name - for cmd := c.parent; cmd != nil; cmd = cmd.parent { - cmdChain = cmd.Name + " " + cmdChain + // Usage line + _, _ = fmt.Fprintln(ww, "Usage:") + _ = ww.SetLinePrefix(prefix4) + _, _ = fmt.Fprint(ww, fullName+" ") + _ = ww.SetLinePrefix(prefix8) + if err := c.flags.printFlagsSingleLine(ww); err != nil { + return err } - - if !short { - _, _ = fmt.Fprintf(w, "%s: %s\n", cmdChain, c.ShortDescription) - _, _ = fmt.Fprintln(w) - if c.LongDescription != "" { - _, _ = fmt.Fprintf(w, "%s\n", c.LongDescription) - _, _ = fmt.Fprintln(w) + _ = ww.SetLinePrefix("") + _, _ = fmt.Fprintln(ww) + _, _ = fmt.Fprintln(ww) + + // Flags + if c.flags.hasFlags() { + _, _ = fmt.Fprintln(ww, "Flags:") + _ = ww.SetLinePrefix(prefix4) + if err := c.flags.printFlagsMultiLine(ww, prefix4); err != nil { + return err } + _ = ww.SetLinePrefix("") + _, _ = fmt.Fprintln(ww) } - flags := &bytes.Buffer{} - lenOfLongestFlagName := 0 - c.flagSet.VisitAll(func(f *flag.Flag) { - if len(f.Name) > lenOfLongestFlagName { - lenOfLongestFlagName = len(f.Name) - } - _, _ = fmt.Fprint(flags, " ") - required := slices.Contains(c.requiredFlags, f.Name) - if !required { - _, _ = fmt.Fprint(flags, "[") - } - if c.flagArgNames != nil { - if valueName, ok := c.flagArgNames[f.Name]; ok { - if valueName != "" { - _, _ = fmt.Fprintf(flags, "--%s=%s", f.Name, c.flagArgNames[f.Name]) - } else { - _, _ = fmt.Fprintf(flags, "--%s", f.Name) - } - } else { - _, _ = fmt.Fprintf(flags, "--%s=VALUE", f.Name) - } - } else { - _, _ = fmt.Fprintf(flags, "--%s=VALUE", f.Name) - } - if !required { - _, _ = fmt.Fprint(flags, "]") - } - }) - positionalArgs := "" - if c.positionalArgsTarget != nil { - positionalArgs = " [ARGS]" - } - - _, _ = fmt.Fprintf(w, "Usage:\n\t%s%s%s\n", cmdChain, flags, positionalArgs) - _, _ = fmt.Fprintln(w) + // Sub-commands + if len(c.subCommands) > 0 { + _, _ = fmt.Fprintln(ww, "Available sub-commands:") - if !short { - if lenOfLongestFlagName > 0 { - var usageStartColumn int - for usageStartColumn = 0; ; usageStartColumn += 10 { - if usageStartColumn > lenOfLongestFlagName { - break - } + lenOfLongestSubCommand := 0 + for _, subCmd := range c.subCommands { + if len(subCmd.name) > lenOfLongestSubCommand { + lenOfLongestSubCommand = len(subCmd.name) } - _, _ = fmt.Fprintf(w, "Flags:\n") - c.flagSet.VisitAll(func(f *flag.Flag) { - flagDesc := f.Usage - if f.DefValue != "" { - flagDesc += fmt.Sprintf(" (default is %s)", f.DefValue) - } - _, _ = fmt.Fprintf(w, "\t--%s%s%s\n", f.Name, strings.Repeat(" ", usageStartColumn-len(f.Name)), flagDesc) - }) - _, _ = fmt.Fprintln(w) } - - if len(c.subCommands) > 0 { - lenOfLongestSubcommandName := 0 - for _, subCmd := range c.subCommands { - if len(subCmd.Name) > lenOfLongestSubcommandName { - lenOfLongestSubcommandName = len(subCmd.Name) - } - } - var usageStartColumn int - for usageStartColumn = 0; ; usageStartColumn += 10 { - if usageStartColumn > lenOfLongestSubcommandName { - break - } - } - _, _ = fmt.Fprintf(w, "Available sub-commands:\n") - for _, subCmd := range c.subCommands { - _, _ = fmt.Fprintf(w, "\t%s%s%s\n", subCmd.Name, strings.Repeat(" ", usageStartColumn-len(subCmd.Name)), subCmd.ShortDescription) - } - _, _ = fmt.Fprintln(w) + subCommandNameDescSpacing := 10 - lenOfLongestSubCommand%10 + subCommandDescriptionCol := lenOfLongestSubCommand + subCommandNameDescSpacing + + for _, subCmd := range c.subCommands { + _ = ww.SetLinePrefix(prefix4) + _, _ = fmt.Fprint(ww, subCmd.name) + _, _ = fmt.Fprint(ww, strings.Repeat(" ", subCommandDescriptionCol-len(subCmd.name))) + _ = ww.SetLinePrefix(strings.Repeat(" ", len(prefix4)+subCommandDescriptionCol)) + _, _ = fmt.Fprintln(ww, subCmd.shortDescription) } - } -} + _, _ = fmt.Fprintln(ww) -func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) (exitCode int) { - if !root.createdByNewCommand { - panic("invalid root command given, indicating it may not have been created by 'command.New(...)'") } - // Iterate CLI args, separate them to flags & positional args, but also infer the command to execute from the given - // non-flags arguments; for example, assuming the following command line is given: - // - // cmd1 -flag1 sub1 something -flag2=1 sub2 -- sub3 -flag3 a b c - // - // And the command hierarchy is: cmd1 -> sub1 -> sub2 -> sub3 - // - // The returned values would be: - // command: sub2 (since it's the last valid command before the "--" which signals positional args only) - // flags: [-flag1, -flag2=1]: no "-flag3" because it's after the "--" separator - // positional args: [something, sub3, a, b, c]: no "cmd1", "sub1" and "sub2" as they are commands in the hierarchy - cmd, flagArgs, positionalArgs := inferCommandFlagsAndPositionals(root, args) - - // Build the command chain from top-to-bottom (so index 0 is the root) - commandChain := []*Command{cmd} - parent := cmd.parent - for parent != nil { - commandChain = append([]*Command{parent}, commandChain...) - parent = parent.parent + if _, err = w.Write([]byte(ww.String())); err != nil { + return err } + return nil +} - // Configure commands up the chain, in order to invoke their "PreSubCommandRun" function - for _, current := range commandChain { - if err := current.configure(envVars, append(flagArgs, positionalArgs...)); err != nil { - current.PrintShortUsage(w) - return 1 - } - - if current.PreSubCommandRun != nil { - if err := current.PreSubCommandRun(ctx, current.Config, current); err != nil { - _, _ = fmt.Fprintln(w, err.Error()) - return 1 - } - } +func (c *Command) PrintUsageLine(w io.Writer, width int) error { + ww, err := NewWrappingWriter(width) + if err != nil { + return err } - // If "--help" was provided, show usage and exit immediately - if cmd.flagSet.Lookup("help").Value.String() == "true" { - cmd.PrintFullUsage(w) - return 0 - } + prefix4 := strings.Repeat(" ", 4) + fullName := c.getFullName() - // If command has no "Run" function, it's an intermediate probably - just print its usage and exit successfully - if cmd.Run == nil { - cmd.PrintFullUsage(w) - return 0 + _, _ = fmt.Fprint(ww, "Usage: ") + _ = ww.SetLinePrefix(prefix4) + _, _ = fmt.Fprint(ww, fullName+" ") + if err := c.flags.printFlagsSingleLine(ww); err != nil { + return err } + _ = ww.SetLinePrefix("") + _, _ = fmt.Fprintln(ww) - // Run the command - if err := cmd.Run(ctx, cmd.Config, cmd); err != nil { - _, _ = fmt.Fprintln(w, err.Error()) - return 1 + if _, err = w.Write([]byte(ww.String())); err != nil { + return err } - - return 0 + return nil } diff --git a/command_test.go b/command_test.go index 285541b..07af71d 100644 --- a/command_test.go +++ b/command_test.go @@ -2,268 +2,426 @@ package command import ( "bytes" - "context" - "regexp" + "reflect" + "strings" "testing" . "github.com/arikkfir/justest" + "github.com/go-loremipsum/loremipsum" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) -func Test_initializeFlagSet(t *testing.T) { - type Flag struct { - Name string - Usage string - Value any - DefValue string - } +func TestNew(t *testing.T) { + t.Parallel() type testCase struct { - cmd *Command - expectedFlags []Flag - expectedFailure *string + commandFactory func(T, *testCase) (*Command, error) + expectedName string + expectedShortDescription string + expectedLongDescription string + expectedError string + expectedFlagSet *flagSet } testCases := map[string]testCase{ - "nil Config": {cmd: New(nil, Spec{Config: nil})}, - "Config is a pointer to a struct": { - cmd: New(nil, Spec{Config: &RootConfig{S0: "v0"}}), - expectedFlags: []Flag{ - {Name: "s0", DefValue: "v0", Usage: "String field", Value: nil}, - {Name: "b0", DefValue: "false", Usage: "Bool field", Value: nil}, + "empty name": { + commandFactory: func(t T, tc *testCase) (*Command, error) { + return New("", "short desc", "long desc", InlineExecutor{}) + }, + expectedError: `^invalid command: empty name$`, + }, + "empty short description": { + commandFactory: func(t T, tc *testCase) (*Command, error) { + return New("cmd", "", "long desc", InlineExecutor{}) + }, + expectedError: `^invalid command: empty short description$`, + }, + "nil executor": { + commandFactory: func(t T, tc *testCase) (*Command, error) { + return New("cmd", "desc", "long desc", nil) + }, + expectedError: `^invalid command: nil executor$`, + }, + "no flags": { + commandFactory: func(t T, tc *testCase) (*Command, error) { + return New("cmd", "desc", "long desc", InlineExecutor{}) + }, + expectedName: "cmd", + expectedShortDescription: "desc", + expectedLongDescription: "long desc", + }, + "with flags": { + commandFactory: func(t T, tc *testCase) (*Command, error) { + return New( + "cmd", + "desc", + "long desc", + &struct { + InlineExecutor + MyFlag string `flag:"true"` + }{}, + ) + }, + expectedFlagSet: &flagSet{ + flags: []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-flag", + HasValue: true, + }, + Targets: []reflect.Value{}, + }, + }, }, }, } for name, tc := range testCases { + tc := tc t.Run(name, func(t *testing.T) { - tc := tc - t.Run(name, func(t *testing.T) { - if tc.expectedFailure != nil { - defer func() { With(t).Verify(recover()).Will(Say(*tc.expectedFailure)).OrFail() }() - } - With(t).Verify(tc.cmd.initializeFlagSet()).Will(Succeed()).OrFail() - for _, expectedFlag := range tc.expectedFlags { - actualFlag := tc.cmd.flagSet.Lookup(expectedFlag.Name) - With(t).Verify(actualFlag).Will(Not(BeNil())).OrFail() - With(t).Verify(actualFlag.Usage).Will(EqualTo(expectedFlag.Usage)).OrFail() - With(t).Verify(actualFlag.DefValue).Will(EqualTo(expectedFlag.DefValue)).OrFail() + t.Parallel() + cmd, err := tc.commandFactory(t, &tc) + if tc.expectedError != "" { + With(t).Verify(err).Will(Fail(tc.expectedError)).OrFail() + } else { + With(t).Verify(err).Will(BeNil()).OrFail() + if tc.expectedFlagSet != nil { + With(t). + Verify(cmd.flags.flags). + Will(EqualTo( + tc.expectedFlagSet.flags, + cmpopts.IgnoreFields(flagDef{}, "Targets"), + cmp.AllowUnexported(flagDef{})), + ). + OrFail() } - }) + } }) } - t.Run("Config must be a pointer to a struct", func(t *testing.T) { - With(t). - Verify(New(nil, Spec{Config: RootConfig{S0: "v0"}}).initializeFlagSet()). - Will(Fail(`must not be a struct, but a pointer to a struct`)). - OrFail() - With(t). - Verify(New(nil, Spec{Config: 1}).initializeFlagSet()). - Will(Fail(`is not a pointer: 1`)). - OrFail() - With(t). - Verify(New(nil, Spec{Config: &[]int{123}[0]}).initializeFlagSet()). - Will(Fail(`is not a pointer to struct`)). - OrFail() - }) } -func Test_printCommandUsage(t *testing.T) { +func TestAddSubCommand(t *testing.T) { t.Parallel() - rootCmd := New(nil, Spec{ - Name: "root", - ShortDescription: "Root command", - LongDescription: "This command is the\nroot command.", - Config: &RootConfig{}, - }) - sub1Cmd := New(rootCmd, Spec{ - Name: "sub1", - ShortDescription: "Sub command 1", - LongDescription: "This command is the\nfirst sub command.", - Config: &Sub1Config{}, - }) - sub2Cmd := New(rootCmd, Spec{ - Name: "sub2", - ShortDescription: "Sub command 2", - LongDescription: "This command is the\nsecond sub command.", - Config: &Sub2Config{}, - }) - sub3Cmd := New(sub2Cmd, Spec{ - Name: "sub3", - ShortDescription: "Sub command 3", - LongDescription: "This command is the\nthird sub command.", - Config: &Sub3Config{}, - }) - - type testCase struct { - cmd *Command - expectedUsage string - } - - testCases := map[string]testCase{ - rootCmd.Name: { - cmd: rootCmd, - expectedUsage: ` -root: Root command - -This command is the -root command. + root, err := New("root", "desc", "description", &InlineExecutor{}) + With(t).Verify(err).Will(BeNil()).OrFail() -Usage: - root [--b0] [--help] --s0=VAL + sub1, err := New("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}) + With(t).Verify(err).Will(BeNil()).OrFail() -Flags: - --b0 Bool field (default is false) - --help Show help about how to use this command (default is false) - --s0 String field + sub2, err := New("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}) + With(t).Verify(err).Will(BeNil()).OrFail() -Available sub-commands: - sub1 Sub command 1 - sub2 Sub command 2 + With(t).Verify(root.AddSubCommand(sub1)).Will(BeNil()).OrFail() + With(t).Verify(root.AddSubCommand(sub2)).Will(BeNil()).OrFail() + With(t).Verify(root.subCommands[0], root.subCommands[1]).Will(EqualTo(sub1, sub2, cmpopts.EquateComparable(&Command{}))).OrFail() + With(t).Verify(sub1.parent).Will(EqualTo(root, cmpopts.EquateComparable(&Command{}))).OrFail() + With(t).Verify(sub2.parent).Will(EqualTo(root, cmpopts.EquateComparable(&Command{}))).OrFail() +} -`, +func Test_inferCommandAndArgs(t *testing.T) { + type testCase struct { + root *Command + args []string + expectedCommand string + expectedFlags []string + expectedPositionals []string + } + testCases := map[string]testCase{ + "No arguments": { + root: MustNew( + "root", "desc", "description", &InlineExecutor{}, + MustNew("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}, + MustNew("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}, + MustNew("sub3", "sub3 desc", "sub3 description", &InlineExecutor{}), + ), + ), + ), + args: []string{}, + expectedCommand: "root", + expectedFlags: nil, + expectedPositionals: nil, }, - sub1Cmd.Name: { - cmd: sub1Cmd, - expectedUsage: ` -root sub1: Sub command 1 - -This command is the -first sub command. - -Usage: - root sub1 [--b0] [--b1] [--help] --s0=VAL --s1=VALUE - -Flags: - --b0 Bool field (default is false) - --b1 Bool field (default is false) - --help Show help about how to use this command (default is false) - --s0 String field - --s1 String field - -`, + "Flags for root command": { + root: MustNew( + "root", "desc", "description", &InlineExecutor{}, + MustNew("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}, + MustNew("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}), + ), + ), + args: strings.Split("-f1 -f2", " "), + expectedCommand: "root", + expectedFlags: []string{"-f1", "-f2"}, + expectedPositionals: nil, }, - sub2Cmd.Name: { - cmd: sub2Cmd, - expectedUsage: ` -root sub2: Sub command 2 - -This command is the -second sub command. - -Usage: - root sub2 [--b0] [--b1] [--b2] [--help] --s0=VAL --s1=VALUE [--s2=VALUE] - -Flags: - --b0 Bool field (default is false) - --b1 Bool field (default is false) - --b2 Bool field (default is false) - --help Show help about how to use this command (default is false) - --s0 String field - --s1 String field - --s2 String field - -Available sub-commands: - sub3 Sub command 3 - -`, + "Flags and positionals for root command": { + root: MustNew( + "root", "desc", "description", &InlineExecutor{}, + MustNew("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}, + MustNew("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}), + ), + ), + args: strings.Split("-f1 a -f2 b", " "), + expectedCommand: "root", + expectedFlags: []string{"-f1", "-f2"}, + expectedPositionals: []string{"a", "b"}, }, - sub3Cmd.Name: { - cmd: sub3Cmd, - expectedUsage: ` -root sub2 sub3: Sub command 3 - -This command is the -third sub command. - -Usage: - root sub2 sub3 [--b0] [--b1] [--b2] [--b3] [--help] --s0=VAL --s1=VALUE [--s2=VALUE] [--s3=VALUE] [ARGS] - -Flags: - --b0 Bool field (default is false) - --b1 Bool field (default is false) - --b2 Bool field (default is false) - --b3 Bool field (default is false) - --help Show help about how to use this command (default is false) - --s0 String field - --s1 String field - --s2 String field - --s3 String field - -`, + "Flags and positionals for sub1 command": { + root: MustNew( + "root", "desc", "description", &InlineExecutor{}, + MustNew("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}, + MustNew("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}), + ), + ), + args: strings.Split("-f1 sub1 -f2 a b", " "), + expectedCommand: "sub1", + expectedFlags: []string{"-f1", "-f2"}, + expectedPositionals: []string{"a", "b"}, + }, + "Flags and positionals for sub2 command": { + root: MustNew( + "root", "desc", "description", &InlineExecutor{}, + MustNew("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}, + MustNew("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}), + ), + ), + args: strings.Split("-f1 sub1 -f2 a b sub2 c", " "), + expectedCommand: "sub2", + expectedFlags: []string{"-f1", "-f2"}, + expectedPositionals: []string{"a", "b", "c"}, }, } for name, tc := range testCases { tc := tc t.Run(name, func(t *testing.T) { - t.Parallel() - usageBuf := &bytes.Buffer{} - - With(t).Verify(tc.cmd.initializeFlagSet()).Will(Succeed()).OrFail() - tc.cmd.printCommandUsage(usageBuf, false) - With(t).Verify(usageBuf.String()).Will(EqualTo(tc.expectedUsage[1:])).OrFail() + flags, positionals, cmd := tc.root.inferCommandAndArgs(tc.args) + With(t).Verify(flags).Will(EqualTo(tc.expectedFlags)).OrFail() + With(t).Verify(positionals).Will(EqualTo(tc.expectedPositionals)).OrFail() + With(t).Verify(cmd.name).Will(EqualTo(tc.expectedCommand)).OrFail() }) } } -func TestExecute(t *testing.T) { - t.Parallel() - - rootSpec := Spec{Name: "root", ShortDescription: "Root!", LongDescription: "The root command.", Config: &RootConfig{}} - sub1Spec := Spec{Name: "sub1", ShortDescription: "Sub 1", LongDescription: "The first sub command.", Config: &Sub1Config{}} - sub2Spec := Spec{Name: "sub2", ShortDescription: "Sub 2", LongDescription: "The second sub command.", Config: &Sub2Config{}} - sub3Spec := Spec{Name: "sub3", ShortDescription: "Sub 3", LongDescription: "The third sub command.", Config: &Sub3Config{}} +func Test_getFullName(t *testing.T) { + type testCase struct { + cmd *Command + expectedFullName string + } + sub3 := MustNew("sub3", "sub3 desc", "sub3 description", &InlineExecutor{}) + sub2 := MustNew("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}, sub3) + sub1 := MustNew("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}, sub2) + root := MustNew("root", "desc", "description", &InlineExecutor{}, sub1) + testCases := map[string]testCase{ + "root": { + cmd: root, + expectedFullName: "root", + }, + "sub1": { + cmd: sub1, + expectedFullName: "root sub1", + }, + "sub2": { + cmd: sub2, + expectedFullName: "root sub1 sub2", + }, + "sub3": { + cmd: sub3, + expectedFullName: "root sub1 sub2 sub3", + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + With(t).Verify(tc.cmd.getFullName()).Will(EqualTo(tc.expectedFullName)).OrFail() + }) + } +} +func Test_getChain(t *testing.T) { type testCase struct { - args []string - envVars []string - expectedExitCode int - expectedPreRunCalls map[string]bool - expectedRunCalls map[string]bool - expectedOutput string + cmd *Command + expectedChain []string } + sub3 := MustNew("sub3", "sub3 desc", "sub3 description", &InlineExecutor{}) + sub2 := MustNew("sub2", "sub2 desc", "sub2 description", &InlineExecutor{}, sub3) + sub1 := MustNew("sub1", "sub1 desc", "sub1 description", &InlineExecutor{}, sub2) + root := MustNew("root", "desc", "description", &InlineExecutor{}, sub1) testCases := map[string]testCase{ - "": { - args: nil, - envVars: nil, - expectedExitCode: 0, - expectedPreRunCalls: map[string]bool{"root": true}, - expectedRunCalls: map[string]bool{"root": true}, + "root": { + cmd: root, + expectedChain: []string{"root"}, }, "sub1": { - args: []string{"sub1"}, - envVars: nil, - expectedExitCode: 0, - expectedPreRunCalls: map[string]bool{"root": true, "sub1": true}, - expectedRunCalls: map[string]bool{"sub1": true}, + cmd: sub1, + expectedChain: []string{"root", "sub1"}, + }, + "sub2": { + cmd: sub2, + expectedChain: []string{"root", "sub1", "sub2"}, }, - "sub2 sub3": { - args: []string{"sub2", "sub3"}, - envVars: nil, - expectedExitCode: 0, - expectedPreRunCalls: map[string]bool{"root": true, "sub2": true, "sub3": true}, - expectedRunCalls: map[string]bool{"sub3": true}, + "sub3": { + cmd: sub3, + expectedChain: []string{"root", "sub1", "sub2", "sub3"}, }, - "sub2 sub3 --help": { - args: []string{"sub2", "sub3", "--help"}, - envVars: nil, - expectedExitCode: 0, - expectedPreRunCalls: map[string]bool{"root": true, "sub2": true, "sub3": true}, - expectedRunCalls: map[string]bool{}, - expectedOutput: `root sub2 sub3: Sub 3 + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + var chainNames []string + for _, cmd := range tc.cmd.getChain() { + chainNames = append(chainNames, cmd.name) + } + With(t).Verify(chainNames).Will(EqualTo(tc.expectedChain)).OrFail() + }) + } +} + +func TestPrintHelp(t *testing.T) { + t.Parallel() -The third sub command. + type testCase struct { + commandFactory func(*testCase) *Command + expectedHelpOutput string + expectedHelpUsageOutput string + } + testCases := map[string]testCase{ + "no flags & no positionals": { + commandFactory: func(*testCase) *Command { + ligen := loremipsum.NewWithSeed(4321) + return MustNew("cmd", ligen.Sentence(), ligen.Sentences(2), InlineExecutor{}) + }, + expectedHelpUsageOutput: ` +Usage: cmd [--help] +`, + expectedHelpOutput: ` +cmd: Lorem ipsum dolor sit amet consectetur + adipiscing elit ac, purus molestie luctus nec + neque cursus conubia vehicula rutrum primis + laoreet vivamus sed nisl lobortis efficitur + ultrices. + +Description: Lorem ipsum dolor sit amet + consectetur adipiscing elit ac, purus + molestie luctus nec. Urna magnis platea risus + habitant diam pellentesque per mauris + consequat, nec ex dis vehicula convallis + habitasse vel molestie auctor suspendisse + efficitur rutrum praesent eleifend quisque + volutpat curae quis lectus. Usage: - root sub2 sub3 [--b0] [--b1] [--b2] [--b3] [--help] --s0=VAL --s1=VALUE [--s2=VALUE] [--s3=VALUE] [ARGS] + cmd [--help] Flags: - --b0 Bool field (default is false) - --b1 Bool field (default is false) - --b2 Bool field (default is false) - --b3 Bool field (default is false) - --help Show help about how to use this command (default is false) - --s0 String field - --s1 String field - --s2 String field - --s3 String field + [--help] Show this help screen and exit. + (default value: false, environment + variable: HELP) + +`, + }, + "with flags, args": { + commandFactory: func(*testCase) *Command { + ligen := loremipsum.NewWithSeed(4321) + return MustNew("cmd", ligen.Sentence(), ligen.Sentences(2), &struct { + InlineExecutor + MyFlag string `desc:"flag description"` + Args []string `args:"true"` + }{}) + }, + expectedHelpUsageOutput: ` +Usage: cmd [--help] + [--my-flag=VALUE] + [ARGS...] +`, + expectedHelpOutput: ` +cmd: Lorem ipsum dolor sit amet consectetur + adipiscing elit ac, purus molestie luctus nec + neque cursus conubia vehicula rutrum primis + laoreet vivamus sed nisl lobortis efficitur + ultrices. + +Description: Lorem ipsum dolor sit amet + consectetur adipiscing elit ac, purus + molestie luctus nec. Urna magnis platea risus + habitant diam pellentesque per mauris + consequat, nec ex dis vehicula convallis + habitasse vel molestie auctor suspendisse + efficitur rutrum praesent eleifend quisque + volutpat curae quis lectus. + +Usage: + cmd [--help] [--my-flag=VALUE] [ARGS...] + +Flags: + [--help] Show this help screen and + exit. (default value: + false, environment + variable: HELP) + [--my-flag=VALUE] flag description + (environment variable: + MY_FLAG) + +`, + }, + "with sub-commands": { + commandFactory: func(*testCase) *Command { + ligen := loremipsum.NewWithSeed(4321) + return MustNew( + "cmd", + ligen.Sentence(), + ligen.Sentences(2), + &struct { + InlineExecutor + MyFlag string `desc:"flag description"` + Args []string `args:"true"` + }{}, + MustNew( + "child1", ligen.Sentence(), ligen.Sentences(2), &struct { + InlineExecutor + SubFlag string `desc:"sub flag description"` + Args []string `args:"true"` + }{}, + ), + ) + }, + expectedHelpUsageOutput: ` +Usage: cmd [--help] + [--my-flag=VALUE] + [ARGS...] +`, + expectedHelpOutput: ` +cmd: Lorem ipsum dolor sit amet consectetur + adipiscing elit ac, purus molestie luctus nec + neque cursus conubia vehicula rutrum primis + laoreet vivamus sed nisl lobortis efficitur + ultrices. + +Description: Lorem ipsum dolor sit amet + consectetur adipiscing elit ac, purus + molestie luctus nec. Urna magnis platea risus + habitant diam pellentesque per mauris + consequat, nec ex dis vehicula convallis + habitasse vel molestie auctor suspendisse + efficitur rutrum praesent eleifend quisque + volutpat curae quis lectus. + +Usage: + cmd [--help] [--my-flag=VALUE] [ARGS...] + +Flags: + [--help] Show this help screen and + exit. (default value: + false, environment + variable: HELP) + [--my-flag=VALUE] flag description + (environment variable: + MY_FLAG) + +Available sub-commands: + child1 Et dolor viverra nulla ipsum + finibus curae conubia gravida + elementum litora eleifend class + porttitor morbi nisi mus non + consequat pharetra convallis + bibendum rhoncus etiam. `, }, @@ -273,62 +431,15 @@ Flags: t.Run(name, func(t *testing.T) { t.Parallel() - preRunCalls := make(map[string]bool) - runCalls := make(map[string]bool) - - rootSpec := rootSpec - rootSpec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - preRunCalls["root"] = true - return nil - } - rootSpec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - runCalls["root"] = true - return nil - } - rootCmd := New(nil, rootSpec) - - sub1Spec := sub1Spec - sub1Spec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - preRunCalls["sub1"] = true - return nil - } - sub1Spec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - runCalls["sub1"] = true - return nil - } - sub1Cmd := New(rootCmd, sub1Spec) - - sub2Spec := sub2Spec - sub2Spec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - preRunCalls["sub2"] = true - return nil - } - sub2Spec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - runCalls["sub2"] = true - return nil - } - sub2Cmd := New(rootCmd, sub2Spec) + cmd := tc.commandFactory(&tc) + b := &bytes.Buffer{} - sub3Spec := sub3Spec - sub3Spec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - preRunCalls["sub3"] = true - return nil - } - sub3Spec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { - runCalls["sub3"] = true - return nil - } - sub3Cmd := New(sub2Cmd, sub3Spec) - _, _, _, _ = rootCmd, sub1Cmd, sub2Cmd, sub3Cmd + With(t).Verify(cmd.PrintHelp(b, 50)).Will(Succeed()).OrFail() + With(t).Verify(b.String()).Will(EqualTo(tc.expectedHelpOutput[1:])).OrFail() - b := &bytes.Buffer{} - exitCode := Execute(context.Background(), b, rootCmd, tc.args, EnvVarsArrayToMap(tc.envVars)) - With(t).Verify(exitCode).Will(EqualTo(tc.expectedExitCode)).OrFail() - With(t).Verify(preRunCalls).Will(EqualTo(tc.expectedPreRunCalls)).OrFail() - With(t).Verify(runCalls).Will(EqualTo(tc.expectedRunCalls)).OrFail() - if tc.expectedOutput != "" { - With(t).Verify(b).Will(Say(`^` + regexp.QuoteMeta(tc.expectedOutput) + `$`)).OrFail() - } + b.Reset() + With(t).Verify(cmd.PrintUsageLine(b, 30)).Will(Succeed()).OrFail() + With(t).Verify(b.String()).Will(EqualTo(tc.expectedHelpUsageOutput[1:])).OrFail() }) } } diff --git a/configs_for_test.go b/configs_for_test.go deleted file mode 100644 index ddf5db0..0000000 --- a/configs_for_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package command - -type RootConfig struct { - S0 string `flag:"valueName=VAL,required" desc:"String field"` - B0 bool `desc:"Bool field"` -} -type Sub1Config struct { - RootConfig - S1 string `flag:"required" desc:"String field"` - B1 bool `desc:"Bool field"` -} -type Sub2Config struct { - Sub1Config - S2 string `desc:"String field"` - B2 bool `desc:"Bool field"` -} -type Sub3Config struct { - Sub2Config - S3 string `desc:"String field"` - B3 bool `desc:"Bool field"` - Args []string `flag:"args" desc:"Arbitrary arguments"` -} diff --git a/execute.go b/execute.go new file mode 100644 index 0000000..78a9a62 --- /dev/null +++ b/execute.go @@ -0,0 +1,91 @@ +package command + +import ( + "context" + "errors" + "fmt" + "io" +) + +type ExitCode int + +const ( + ExitCodeSuccess ExitCode = 0 + ExitCodeError ExitCode = 1 + ExitCodeMisconfiguration ExitCode = 2 +) + +// Executor is the interface to be implemented by custom commands. +type Executor interface { + PreRun(ctx context.Context) error + Run(ctx context.Context) error +} + +type InlineExecutor struct { + PreRunFunc func(context.Context) error + RunFunc func(context.Context) error +} + +func (i InlineExecutor) PreRun(ctx context.Context) error { + if i.PreRunFunc != nil { + return i.PreRunFunc(ctx) + } else { + return nil + } +} + +func (i InlineExecutor) Run(ctx context.Context) error { + if i.PreRunFunc != nil { + return i.RunFunc(ctx) + } else { + return nil + } +} + +// Execute the correct command in the given command hierarchy (starting at "root"), configured from the given CLI args +// and environment variables. The command will be executed with the given context after all pre-RunFunc hooks have been +// successfully executed in the command hierarchy. +func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) ExitCode { + if root.parent != nil { + _, _ = fmt.Fprintf(w, "%s: command must be the root command", errors.ErrUnsupported) + return ExitCodeError + } + + // Extract the command, CLI flags, positional arguments & the command hierarchy + flags, positionals, cmd := root.inferCommandAndArgs(args) + + // Create flagSet & apply it to the configuration structs + // If "--help" is given, print help and exit + if err := cmd.flags.apply(envVars, append(flags, positionals...)); err != nil { + _, _ = fmt.Fprintln(w, err) + if err := cmd.PrintUsageLine(w, getTerminalWidth()); err != nil { + _, _ = fmt.Fprintf(w, "%s\n", err) + return ExitCodeError + } else { + return ExitCodeMisconfiguration + } + } else if cmd.HelpConfig.Help { + if err := cmd.PrintHelp(w, getTerminalWidth()); err != nil { + _, _ = fmt.Fprintf(w, "%s\n", err) + return ExitCodeError + } else { + return ExitCodeSuccess + } + } + + // Invoke all "PreRun" hooks on the whole chain of commands (starting at the root) + for _, c := range cmd.getChain() { + if err := c.executor.PreRun(ctx); err != nil { + _, _ = fmt.Fprintln(w, err) + return ExitCodeError + } + } + + // Run the command + if err := cmd.executor.Run(ctx); err != nil { + _, _ = fmt.Fprintln(w, err) + return ExitCodeError + } + + return ExitCodeSuccess +} diff --git a/execute_test.go b/execute_test.go new file mode 100644 index 0000000..16019a0 --- /dev/null +++ b/execute_test.go @@ -0,0 +1,115 @@ +package command + +import ( + "bytes" + "context" + "os" + "testing" + "time" + + . "github.com/arikkfir/justest" +) + +type TrackingExecutor struct { + preRunCalled *time.Time + preRunErrorToReturn error + runCalled *time.Time + runErrorToReturn error +} + +func (te *TrackingExecutor) PreRun(_ context.Context) error { + te.preRunCalled = ptrOf(time.Now()) + time.Sleep(100 * time.Millisecond) + return te.preRunErrorToReturn +} + +func (te *TrackingExecutor) Run(_ context.Context) error { + te.runCalled = ptrOf(time.Now()) + time.Sleep(100 * time.Millisecond) + return te.runErrorToReturn +} + +type ExecutorWithFlag struct { + MyFlag string `name:"my-flag"` +} + +func (e *ExecutorWithFlag) PreRun(_ context.Context) error { + return nil +} + +func (e *ExecutorWithFlag) Run(_ context.Context) error { + return nil +} + +func TestExecute(t *testing.T) { + t.Parallel() + + t.Run("command must be root", func(t *testing.T) { + ctx := context.Background() + child := MustNew("child", "desc", "long desc", InlineExecutor{}) + _ = MustNew("root", "desc", "long desc", InlineExecutor{}, child) + b := &bytes.Buffer{} + With(t).Verify(Execute(ctx, b, child, nil, nil)).Will(EqualTo(ExitCodeError)).OrFail() + With(t).Verify(b).Will(Say(`^unsupported operation: command must be the root command$`)).OrFail() + }) + + t.Run("applies configuration", func(t *testing.T) { + ctx := context.Background() + cmd := MustNew("cmd", "desc", "long desc", &ExecutorWithFlag{}) + With(t).Verify(Execute(ctx, os.Stderr, cmd, []string{"--my-flag=V1"}, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail() + With(t).Verify(cmd.executor.(*ExecutorWithFlag).MyFlag).Will(EqualTo("V1")).OrFail() + }) + + t.Run("prints usage on CLI parse errors", func(t *testing.T) { + ctx := context.Background() + cmd := MustNew("cmd", "desc", "long desc", &ExecutorWithFlag{}) + b := &bytes.Buffer{} + With(t).Verify(Execute(ctx, b, cmd, []string{"--bad-flag=V1"}, nil)).Will(EqualTo(ExitCodeMisconfiguration)).OrFail() + With(t).Verify(cmd.executor.(*ExecutorWithFlag).MyFlag).Will(BeEmpty()).OrFail() + With(t).Verify(b.String()).Will(EqualTo("unknown flag: --bad-flag\nUsage: cmd [--help] [--my-flag=VALUE]\n")).OrFail() + }) + + t.Run("prints help on --help flag", func(t *testing.T) { + ctx := context.Background() + cmd := MustNew("cmd", "desc", "long desc", &ExecutorWithFlag{}) + b := &bytes.Buffer{} + With(t).Verify(Execute(ctx, b, cmd, []string{"--help"}, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail() + With(t).Verify(b.String()).Will(EqualTo(` +cmd: desc + +Description: long desc + +Usage: + cmd [--help] [--my-flag=VALUE] + +Flags: + [--help] Show this help screen and exit. (default value: false, + environment variable: HELP) + [--my-flag=VALUE] environment variable: MY_FLAG + +`[1:])).OrFail() + }) + + t.Run("preRun called for command chain", func(t *testing.T) { + ctx := context.Background() + sub2 := MustNew("sub2", "desc", "long desc", &TrackingExecutor{}) + sub1 := MustNew("sub1", "desc", "long desc", &TrackingExecutor{}, sub2) + root := MustNew("cmd", "desc", "long desc", &TrackingExecutor{}, sub1) + With(t).Verify(Execute(ctx, os.Stderr, root, []string{"sub1", "sub2"}, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail() + + sub2PreRunTime := sub2.executor.(*TrackingExecutor).preRunCalled + With(t).Verify(sub2PreRunTime).Will(Not(BeNil())).OrFail() + + sub1PreRunTime := sub1.executor.(*TrackingExecutor).preRunCalled + With(t).Verify(sub1PreRunTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub1PreRunTime.Before(*sub2PreRunTime)).Will(EqualTo(true)).OrFail() + + rootPreRunTime := root.executor.(*TrackingExecutor).preRunCalled + With(t).Verify(rootPreRunTime).Will(Not(BeNil())).OrFail() + With(t).Verify(rootPreRunTime.Before(*sub1PreRunTime)).Will(EqualTo(true)).OrFail() + + sub2RunTime := sub2.executor.(*TrackingExecutor).runCalled + With(t).Verify(sub2RunTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2RunTime.After(*sub2PreRunTime)).Will(EqualTo(true)).OrFail() + }) +} diff --git a/flag_def.go b/flag_def.go new file mode 100644 index 0000000..264993f --- /dev/null +++ b/flag_def.go @@ -0,0 +1,166 @@ +package command + +import ( + "cmp" + "errors" + "fmt" + "reflect" + "strconv" +) + +type ErrInvalidValue struct { + Cause error + Value string + Flag string +} + +func (e *ErrInvalidValue) Error() string { + return fmt.Sprintf("invalid value '%s' for flag '%s': %s", e.Value, e.Flag, e.Cause) +} + +func (e *ErrInvalidValue) Unwrap() error { + return e.Cause +} + +type flagInfo struct { + Name string + EnvVarName *string + HasValue bool + ValueName *string + Description *string + Required *bool + DefaultValue string +} + +type flagDef struct { + flagInfo + Inherited bool + Targets []reflect.Value + applied bool +} + +func (fd *flagDef) isRequired() bool { + return fd.Required != nil && *fd.Required +} + +func (fd *flagDef) getValueName() string { + if fd.HasValue { + if fd.ValueName != nil { + return *fd.ValueName + } else { + return "VALUE" + } + } else { + return "" + } +} + +func (fd *flagDef) setValue(sv string) error { + for _, fv := range fd.Targets { + switch fv.Kind() { + case reflect.Bool: + if b, err := strconv.ParseBool(sv); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + return &ErrInvalidValue{Cause: ne.Err, Value: ne.Num, Flag: fd.Name} + } else { + return &ErrInvalidValue{Cause: err, Value: sv, Flag: fd.Name} + } + } else { + fv.SetBool(b) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if i, err := strconv.ParseInt(sv, 10, 64); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + return &ErrInvalidValue{Cause: ne.Err, Value: ne.Num, Flag: fd.Name} + } else { + return &ErrInvalidValue{Cause: err, Value: sv, Flag: fd.Name} + } + } else { + fv.SetInt(i) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if ui, err := strconv.ParseUint(sv, 10, 64); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + return &ErrInvalidValue{Cause: ne.Err, Value: ne.Num, Flag: fd.Name} + } else { + return &ErrInvalidValue{Cause: err, Value: sv, Flag: fd.Name} + } + } else { + fv.SetUint(ui) + } + case reflect.Float32, reflect.Float64: + if f, err := strconv.ParseFloat(sv, 64); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + return &ErrInvalidValue{Cause: ne.Err, Value: ne.Num, Flag: fd.Name} + } else { + return &ErrInvalidValue{Cause: err, Value: sv, Flag: fd.Name} + } + } else { + fv.SetFloat(f) + } + case reflect.String: + fv.SetString(sv) + default: + return fmt.Errorf("%w: field kind is '%s'", errors.ErrUnsupported, fv.Kind()) + } + } + fd.applied = true + return nil +} + +func (fd *flagDef) isLessThan(b *flagDef) bool { + a := fd + name := cmp.Compare(a.Name, b.Name) + if name < 0 { + return true + } else if name > 0 { + return false + } + envVarName := cmp.Compare(defaultIfNil(a.EnvVarName, ""), defaultIfNil(b.EnvVarName, "")) + if envVarName < 0 { + return true + } else if envVarName > 0 { + return false + } + hasValue := cmp.Compare(intForBool(a.HasValue), intForBool(b.HasValue)) + if hasValue < 0 { + return true + } else if hasValue > 0 { + return false + } + valueName := cmp.Compare(defaultIfNil(a.ValueName, ""), defaultIfNil(b.ValueName, "")) + if valueName < 0 { + return true + } else if valueName > 0 { + return false + } + description := cmp.Compare(defaultIfNil(a.Description, ""), defaultIfNil(b.Description, "")) + if description < 0 { + return true + } else if description > 0 { + return false + } + required := cmp.Compare(intForBool(defaultIfNil(a.Required, false)), intForBool(defaultIfNil(b.Required, false))) + if required < 0 { + return true + } else if required > 0 { + return false + } + defaultValue := cmp.Compare(a.DefaultValue, b.DefaultValue) + if defaultValue < 0 { + return true + } else if defaultValue > 0 { + return false + } + inherited := cmp.Compare(intForBool(a.Inherited), intForBool(b.Inherited)) + if inherited < 0 { + return true + } else if inherited > 0 { + return false + } + return false +} diff --git a/flag_def_test.go b/flag_def_test.go new file mode 100644 index 0000000..1db77c3 --- /dev/null +++ b/flag_def_test.go @@ -0,0 +1,311 @@ +package command + +import ( + "math" + "reflect" + "strconv" + "testing" + + . "github.com/arikkfir/justest" +) + +func TestFlagDefIsRequired(t *testing.T) { + t.Parallel() + + type testCase struct { + fd *flagDef + expectedRequired bool + } + + testCases := map[string]testCase{ + "nil": {fd: &flagDef{flagInfo: flagInfo{Name: "my-flag"}}, expectedRequired: false}, + "*true": {fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{true}[0]}}, expectedRequired: true}, + "*false": {fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{false}[0]}}, expectedRequired: false}, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + With(t).Verify(tc.fd.isRequired()).Will(EqualTo(tc.expectedRequired)).OrFail() + }) + } +} + +func TestFlagDefGetValueName(t *testing.T) { + t.Parallel() + + type testCase struct { + fd *flagDef + expectedValueName string + } + + testCases := map[string]testCase{ + "does not have value": {fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: false}}, expectedValueName: ""}, + "has value & has value name": {fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: true, ValueName: &[]string{"VVV"}[0]}}, expectedValueName: "VVV"}, + "has value & has no value name": {fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: true}}, expectedValueName: "VALUE"}, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + With(t).Verify(tc.fd.getValueName()).Will(EqualTo(tc.expectedValueName)).OrFail() + }) + } +} + +func TestFlagDefSetValue(t *testing.T) { + t.Parallel() + type Target struct { + B bool + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + S string + } + type testCase struct { + target *Target + targetsFactory func(tc *testCase) []reflect.Value + value string + expectedTarget Target + expectedError string + } + testCases := map[string]testCase{ + "valid bool": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("B")} + }, + value: "true", + expectedTarget: Target{B: true}, + }, + "invalid bool": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("B")} + }, + value: "bad bool", + expectedError: `^invalid value 'bad bool' for flag 'my-flag': invalid syntax$`, + }, + "valid int": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I")} + }, + value: strconv.FormatInt(math.MaxInt, 10), + expectedTarget: Target{I: math.MaxInt}, + }, + "invalid int": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid int8": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I8")} + }, + value: strconv.FormatInt(math.MaxInt8, 10), + expectedTarget: Target{I8: math.MaxInt8}, + }, + "invalid int8": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I8")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid int16": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I16")} + }, + value: strconv.FormatInt(math.MaxInt16, 10), + expectedTarget: Target{I16: math.MaxInt16}, + }, + "invalid int16": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I16")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid int32": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I32")} + }, + value: strconv.FormatInt(math.MaxInt32, 10), + expectedTarget: Target{I32: math.MaxInt32}, + }, + "invalid int32": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I32")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid int64": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I64")} + }, + value: strconv.FormatInt(math.MaxInt64, 10), + expectedTarget: Target{I64: math.MaxInt64}, + }, + "invalid int64": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("I64")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid uint": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI")} + }, + value: strconv.FormatUint(math.MaxUint, 10), + expectedTarget: Target{UI: math.MaxUint}, + }, + "invalid uint": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid uint8": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI8")} + }, + value: strconv.FormatUint(math.MaxUint8, 10), + expectedTarget: Target{UI8: math.MaxUint8}, + }, + "invalid uint8": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI8")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid uint16": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI16")} + }, + value: strconv.FormatUint(math.MaxUint16, 10), + expectedTarget: Target{UI16: math.MaxUint16}, + }, + "invalid uint16": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI16")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid uint32": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI32")} + }, + value: strconv.FormatUint(math.MaxUint32, 10), + expectedTarget: Target{UI32: math.MaxUint32}, + }, + "invalid uint32": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI32")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid uint64": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI64")} + }, + value: strconv.FormatUint(math.MaxUint64, 10), + expectedTarget: Target{UI64: math.MaxUint64}, + }, + "invalid uint64": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("UI64")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid float32": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("F32")} + }, + value: strconv.FormatFloat(math.MaxFloat32, 'g', -1, 64), + expectedTarget: Target{F32: math.MaxFloat32}, + }, + "invalid float32": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("F32")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "valid float64": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("F64")} + }, + value: strconv.FormatFloat(math.MaxFloat64, 'g', -1, 64), + expectedTarget: Target{F64: math.MaxFloat64}, + }, + "invalid float64": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("F64")} + }, + value: "abc", + expectedError: `^invalid value 'abc' for flag 'my-flag': invalid syntax$`, + }, + "string": { + target: &Target{}, + targetsFactory: func(tc *testCase) []reflect.Value { + return []reflect.Value{reflect.ValueOf(tc.target).Elem().FieldByName("S")} + }, + value: "abc", + expectedTarget: Target{S: "abc"}, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + fd := &flagDef{flagInfo: flagInfo{Name: "my-flag"}, Targets: tc.targetsFactory(&tc)} + err := fd.setValue(tc.value) + if tc.expectedError != "" { + With(t).Verify(err).Will(Fail(tc.expectedError)).OrFail() + } else { + With(t).Verify(err).Will(BeNil()).OrFail() + With(t).Verify(*tc.target).Will(EqualTo(tc.expectedTarget)).OrFail() + } + }) + } +} diff --git a/flag_merged.go b/flag_merged.go new file mode 100644 index 0000000..09b504f --- /dev/null +++ b/flag_merged.go @@ -0,0 +1,102 @@ +package command + +import ( + "fmt" +) + +type mergedFlagDef struct { + flagInfo + applied bool + flagDefs []*flagDef +} + +func (mfd *mergedFlagDef) addFlagDef(fd *flagDef) error { + if fd.Name != mfd.Name { + return fmt.Errorf("given flag '%s' has incompatible name - must be '%s'", fd.Name, mfd.Name) + } + + if mfd.EnvVarName == nil { + if fd.EnvVarName != nil { + mfd.EnvVarName = fd.EnvVarName + } + } else if fd.EnvVarName != nil { + if *mfd.EnvVarName != *fd.EnvVarName { + return fmt.Errorf("flag '%s' has incompatible environment variable name '%v' - must be '%v'", fd.Name, *fd.EnvVarName, *mfd.EnvVarName) + } + } + + if fd.HasValue != mfd.HasValue { + if mfd.HasValue { + return fmt.Errorf("given flag '%s' must have a value, but it does not", fd.Name) + } else { + return fmt.Errorf("given flag '%s' must not have a value, but it does", fd.Name) + } + } + + if mfd.ValueName == nil { + if fd.ValueName != nil { + mfd.ValueName = fd.ValueName + } + } else if fd.ValueName != nil { + if *mfd.ValueName != *fd.ValueName { + return fmt.Errorf("flag '%s' has incompatible value-name '%v' - must be '%v'", fd.Name, *fd.ValueName, *mfd.ValueName) + } + } + + if mfd.Description == nil { + if fd.Description != nil { + mfd.Description = fd.Description + } + } else if fd.Description != nil { + if *mfd.Description != *fd.Description { + return fmt.Errorf("flag '%s' has incompatible description", fd.Name) + } + } + + if mfd.Required == nil { + if fd.Required != nil { + mfd.Required = fd.Required + } + } else if *mfd.Required { + if fd.Required != nil && !*fd.Required { + return fmt.Errorf("flag '%s' is incompatibly optional - must be required", fd.Name) + } + } + + if fd.DefaultValue != mfd.DefaultValue { + return fmt.Errorf("flag '%s' has incompatible default value '%s' - must be '%s'", fd.Name, fd.DefaultValue, mfd.DefaultValue) + } + + mfd.flagDefs = append(mfd.flagDefs, fd) + return nil +} + +func (mfd *mergedFlagDef) setValue(v string) error { + mfd.applied = true + for _, fd := range mfd.flagDefs { + if err := fd.setValue(v); err != nil { + return err + } + } + return nil +} + +func (mfd *mergedFlagDef) isRequired() bool { + return mfd.Required != nil && *mfd.Required +} + +func (mfd *mergedFlagDef) isMissing() bool { + return mfd.isRequired() && !mfd.applied +} + +func (mfd *mergedFlagDef) getValueName() string { + if mfd.HasValue { + if mfd.ValueName != nil { + return *mfd.ValueName + } else { + return "VALUE" + } + } else { + return "" + } +} diff --git a/flag_merged_test.go b/flag_merged_test.go new file mode 100644 index 0000000..2f78084 --- /dev/null +++ b/flag_merged_test.go @@ -0,0 +1,263 @@ +package command + +import ( + "reflect" + "slices" + "testing" + + . "github.com/arikkfir/justest" +) + +func TestMergedFlagDefAddFlagDef(t *testing.T) { + t.Parallel() + + type testCase struct { + mfd *mergedFlagDef + fd *flagDef + expectedError string + verifier func(t T, tc *testCase) + } + + testCases := map[string]testCase{ + "valid": { + mfd: &mergedFlagDef{ + flagInfo: flagInfo{ + Name: "my-flag", + EnvVarName: ptrOf("MY_FLAG"), + HasValue: true, + ValueName: &[]string{"VVV"}[0], + Description: &[]string{"This is the description"}[0], + Required: &[]bool{true}[0], + DefaultValue: "abc", + }, + }, + fd: &flagDef{ + flagInfo: flagInfo{ + Name: "my-flag", + EnvVarName: ptrOf("MY_FLAG"), + HasValue: true, + ValueName: &[]string{"VVV"}[0], + Description: &[]string{"This is the description"}[0], + Required: &[]bool{true}[0], + DefaultValue: "abc", + }, + }, + verifier: func(t T, tc *testCase) { + With(t).Verify(tc.mfd.Name).Will(EqualTo(tc.fd.Name)).OrFail() + With(t).Verify(tc.mfd.EnvVarName).Will(EqualTo(tc.fd.EnvVarName)).OrFail() + With(t).Verify(tc.mfd.HasValue).Will(EqualTo(tc.fd.HasValue)).OrFail() + With(t).Verify(tc.mfd.ValueName).Will(EqualTo(tc.fd.ValueName)).OrFail() + With(t).Verify(tc.mfd.Description).Will(EqualTo(tc.fd.Description)).OrFail() + With(t).Verify(tc.mfd.Required).Will(EqualTo(tc.fd.Required)).OrFail() + With(t).Verify(tc.mfd.DefaultValue).Will(EqualTo(tc.fd.DefaultValue)).OrFail() + }, + }, + "unexpected name": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag"}}, + fd: &flagDef{flagInfo: flagInfo{Name: "other-flag"}}, + expectedError: `given flag 'other-flag' has incompatible name - must be 'my-flag'`, + }, + "unexpected environment variable": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", EnvVarName: ptrOf("MY_FLAG")}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", EnvVarName: ptrOf("BAD_FLAG")}}, + expectedError: `flag 'my-flag' has incompatible environment variable name 'BAD_FLAG' - must be 'MY_FLAG'`, + }, + "expected flag to have a value": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: true}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: false}}, + expectedError: `given flag 'my-flag' must have a value, but it does not`, + }, + "expected flag to not have a value": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: false}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: true}}, + expectedError: `given flag 'my-flag' must not have a value, but it does`, + }, + "given value-name overrides nil value name": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag"}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", ValueName: &[]string{"val"}[0]}}, + verifier: func(t T, tc *testCase) { + With(t).Verify(tc.mfd.ValueName).Will(EqualTo(tc.fd.ValueName)).OrFail() + }, + }, + "given value-name equals existing value name does nothing": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", ValueName: &[]string{"val"}[0]}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", ValueName: &[]string{"val"}[0]}}, + verifier: func(t T, tc *testCase) { + With(t).Verify(tc.mfd.ValueName).Will(EqualTo(tc.fd.ValueName)).OrFail() + }, + }, + "unexpected value name": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", ValueName: &[]string{"val1"}[0]}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", ValueName: &[]string{"val2"}[0]}}, + expectedError: `flag 'my-flag' has incompatible value-name 'val2' - must be 'val1'`, + }, + "given description overrides nil description": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag"}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Description: &[]string{"desc"}[0]}}, + verifier: func(t T, tc *testCase) { + With(t).Verify(tc.mfd.Description).Will(EqualTo(tc.fd.Description)).OrFail() + }, + }, + "given description equals existing description does nothing": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Description: &[]string{"desc"}[0]}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Description: &[]string{"desc"}[0]}}, + verifier: func(t T, tc *testCase) { + With(t).Verify(tc.mfd.Description).Will(EqualTo(tc.fd.Description)).OrFail() + }, + }, + "unexpected description": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Description: &[]string{"desc1"}[0]}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Description: &[]string{"desc2"}[0]}}, + expectedError: `flag 'my-flag' has incompatible description`, + }, + "given required overrides nil required": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag"}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{true}[0]}}, + verifier: func(t T, tc *testCase) { + With(t).Verify(tc.mfd.Required).Will(EqualTo(tc.fd.Required)).OrFail() + }, + }, + "given required equals existing required does nothing": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{true}[0]}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{true}[0]}}, + verifier: func(t T, tc *testCase) { + With(t).Verify(tc.mfd.Required).Will(EqualTo(tc.fd.Required)).OrFail() + }, + }, + "unexpected required": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{true}[0]}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{false}[0]}}, + expectedError: `flag 'my-flag' is incompatibly optional - must be required`, + }, + "unexpected default value": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", DefaultValue: "abc"}}, + fd: &flagDef{flagInfo: flagInfo{Name: "my-flag", DefaultValue: "abcdef"}}, + expectedError: `flag 'my-flag' has incompatible default value 'abcdef' - must be 'abc'`, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + if tc.expectedError != "" { + With(t).Verify(tc.mfd.addFlagDef(tc.fd)).Will(Fail(tc.expectedError)).OrFail() + With(t).Verify(slices.Contains(tc.mfd.flagDefs, tc.fd)).Will(EqualTo(false)).OrFail() + } else { + With(t).Verify(tc.mfd.addFlagDef(tc.fd)).Will(Succeed()).OrFail() + With(t).Verify(slices.Contains(tc.mfd.flagDefs, tc.fd)).Will(EqualTo(true)).OrFail() + } + if tc.verifier != nil { + tc.verifier(t, &tc) + } + }) + } +} + +func TestMergedFlagDefSetValue(t *testing.T) { + t.Parallel() + + targets := [3]string{} + mfd := &mergedFlagDef{ + flagInfo: flagInfo{ + Name: "my-flag", + HasValue: true, + }, + flagDefs: []*flagDef{ + {flagInfo: flagInfo{Name: "my-flag", HasValue: true}, Targets: []reflect.Value{reflect.ValueOf(&targets).Elem().Index(0)}}, + {flagInfo: flagInfo{Name: "my-flag", HasValue: true}, Targets: []reflect.Value{reflect.ValueOf(&targets).Elem().Index(1)}}, + {flagInfo: flagInfo{Name: "my-flag", HasValue: true}, Targets: []reflect.Value{reflect.ValueOf(&targets).Elem().Index(2)}}, + }, + } + + With(t).Verify(mfd.setValue("v1")).Will(Succeed()).OrFail() + With(t).Verify(targets).Will(EqualTo([3]string{"v1", "v1", "v1"})).OrFail() +} + +func TestMergedFlagDefIsRequired(t *testing.T) { + t.Parallel() + + type testCase struct { + mfd *mergedFlagDef + expectedRequired bool + } + + testCases := map[string]testCase{ + "nil": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag"}}, + expectedRequired: false, + }, + "*true": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{true}[0]}}, + expectedRequired: true, + }, + "*false": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{false}[0]}}, + expectedRequired: false, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + With(t).Verify(tc.mfd.isRequired()).Will(EqualTo(tc.expectedRequired)).OrFail() + }) + } +} + +func TestMergedFlagDefIsMissing(t *testing.T) { + t.Parallel() + + type testCase struct { + mfd *mergedFlagDef + expectedMissing bool + } + + testCases := map[string]testCase{ + "required & not applied": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{true}[0]}, applied: false}, + expectedMissing: true, + }, + "not required & not applied": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", Required: &[]bool{false}[0]}, applied: false}, + expectedMissing: false, + }, + "implicitly not required & not applied": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag"}, applied: false}, + expectedMissing: false, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + With(t).Verify(tc.mfd.isMissing()).Will(EqualTo(tc.expectedMissing)).OrFail() + }) + } +} + +func TestMergedFlagDefGetValueName(t *testing.T) { + t.Parallel() + + type testCase struct { + mfd *mergedFlagDef + expectedValueName string + } + + testCases := map[string]testCase{ + "does not have value": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: false}}, + expectedValueName: "", + }, + "has value & has value name": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: true, ValueName: &[]string{"VVV"}[0]}}, + expectedValueName: "VVV", + }, + "has value & has no value name": { + mfd: &mergedFlagDef{flagInfo: flagInfo{Name: "my-flag", HasValue: true}}, + expectedValueName: "VALUE", + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + With(t).Verify(tc.mfd.getValueName()).Will(EqualTo(tc.expectedValueName)).OrFail() + }) + } +} diff --git a/flag_set.go b/flag_set.go new file mode 100644 index 0000000..c04da1d --- /dev/null +++ b/flag_set.go @@ -0,0 +1,509 @@ +package command + +import ( + "cmp" + "errors" + "flag" + "fmt" + "io" + "reflect" + "regexp" + "sort" + "strconv" + "strings" +) + +type Tag string + +const ( + TagFlag Tag = "flag" + TagName Tag = "name" + TagEnv Tag = "env" + TagValueName Tag = "value-name" + TagDescription Tag = "desc" + TagRequired Tag = "required" + TagInherited Tag = "inherited" + TagArgs Tag = "args" +) + +type ErrInvalidTag struct { + Cause error + Tag Tag + Value string +} + +func (e *ErrInvalidTag) Error() string { + return fmt.Sprintf("invalid tag '%s=%s': %s", e.Tag, e.Value, e.Cause) +} + +func (e *ErrInvalidTag) Unwrap() error { + return e.Cause +} + +type ErrUnknownFlag struct { + Cause error + Flag string +} + +func (e *ErrUnknownFlag) Error() string { + return fmt.Sprintf("unknown flag: --%s", e.Flag) +} + +func (e *ErrUnknownFlag) Unwrap() error { + return e.Cause +} + +type ErrRequiredFlagMissing struct { + Cause error + Flag string +} + +func (e *ErrRequiredFlagMissing) Error() string { + return fmt.Sprintf("required flag is missing: --%s", e.Flag) +} + +func (e *ErrRequiredFlagMissing) Unwrap() error { + return e.Cause +} + +type flagSet struct { + flags []*flagDef + parent *flagSet + positionalsTargets []*[]string +} + +func newFlagSet(parent *flagSet, valueOfConfig reflect.Value) (*flagSet, error) { + fs := &flagSet{parent: parent} + if valueOfConfig.Kind() == reflect.Ptr && valueOfConfig.Type().Elem().Kind() == reflect.Struct { + if valueOfConfig.IsNil() { + valueOfConfig.Set(reflect.New(valueOfConfig.Type().Elem())) + } + if err := fs.readFlagsFromStruct(valueOfConfig.Elem(), false); err != nil { + return nil, err + } + } + return fs, nil +} + +func (fs *flagSet) hasFlags() bool { + if len(fs.flags) > 0 { + return true + } + for _fs := fs.parent; _fs != nil; _fs = _fs.parent { + for _, fd := range _fs.flags { + if fd.Inherited { + return true + } + } + } + return false +} + +func (fs *flagSet) readFlagsFromStruct(s reflect.Value, defaultInherited bool) error { + for i := 0; i < s.NumField(); i++ { + fieldValue := s.Field(i) + structField := s.Type().Field(i) + fieldName := structField.Name + if err := fs.readFlagFromField(fieldValue, structField, defaultInherited); err != nil { + return fmt.Errorf("invalid field '%s.%s': %w", s.Type(), fieldName, err) + } + } + return nil +} + +func (fs *flagSet) readFlagFromField(fieldValue reflect.Value, structField reflect.StructField, defaultInherited bool) error { + fieldName := structField.Name + + // Initial configuration of this field + var args bool + var flagTag Tag + fd := &flagDef{ + flagInfo: flagInfo{Name: fieldNameToFlagName(fieldName)}, + Inherited: defaultInherited, + Targets: []reflect.Value{fieldValue}, + } + + // Read field tags + if tag, ok := structField.Tag.Lookup(string(TagFlag)); ok { + if v, err := strconv.ParseBool(tag); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + err = ne.Err + } + return &ErrInvalidTag{Cause: err, Tag: TagFlag, Value: tag} + } else if !v { + return nil + } else { + flagTag = TagFlag + } + } + if tag, ok := structField.Tag.Lookup(string(TagName)); ok { + if tag == "" { + return &ErrInvalidTag{Cause: fmt.Errorf("must not be empty"), Tag: TagName, Value: tag} + } + flagTag = TagName + fd.flagInfo.Name = tag + } + if tag, ok := structField.Tag.Lookup(string(TagEnv)); ok { + if tag == "" { + return &ErrInvalidTag{Cause: fmt.Errorf("must not be empty"), Tag: TagEnv, Value: tag} + } else { + tag = strings.ToUpper(tag) + } + flagTag = TagEnv + fd.flagInfo.EnvVarName = &tag + } + if tag, ok := structField.Tag.Lookup(string(TagValueName)); ok { + if tag == "" { + return &ErrInvalidTag{Cause: fmt.Errorf("must not be empty"), Tag: TagValueName, Value: tag} + } else if fieldValue.Kind() == reflect.Bool { + return &ErrInvalidTag{Cause: fmt.Errorf("not supported for bool fields"), Tag: TagValueName, Value: tag} + } + flagTag = TagValueName + fd.flagInfo.ValueName = &tag + } + if tag, ok := structField.Tag.Lookup(string(TagDescription)); ok { + flagTag = TagDescription + fd.flagInfo.Description = &tag + } + if tag, ok := structField.Tag.Lookup(string(TagRequired)); ok { + if v, err := strconv.ParseBool(tag); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + err = ne.Err + } + return &ErrInvalidTag{Cause: err, Tag: TagRequired, Value: tag} + } else { + flagTag = TagRequired + fd.flagInfo.Required = ptrOf(v) + } + } + if tag, ok := structField.Tag.Lookup(string(TagInherited)); ok { + if v, err := strconv.ParseBool(tag); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + err = ne.Err + } + return &ErrInvalidTag{Cause: err, Tag: TagInherited, Value: tag} + } else { + flagTag = TagInherited + fd.Inherited = v + } + } + if tag, ok := structField.Tag.Lookup(string(TagArgs)); ok { + if v, err := strconv.ParseBool(tag); err != nil { + var ne *strconv.NumError + if errors.As(err, &ne) { + err = ne.Err + } + return &ErrInvalidTag{Cause: err, Tag: TagArgs, Value: tag} + } else { + args = v + } + } + + if fieldValue.Kind() == reflect.Struct { + // Struct fields are only containers for other fields; if the struct is tagged with "args" or any flag tag, fail + if args { + return &ErrInvalidTag{Cause: fmt.Errorf("cannot be used on struct fields"), Tag: TagArgs, Value: strconv.FormatBool(args)} + } else if flagTag != "" { + return &ErrInvalidTag{Cause: fmt.Errorf("cannot be used on struct fields"), Tag: flagTag, Value: structField.Tag.Get(string(flagTag))} + } else if err := fs.readFlagsFromStruct(fieldValue, fd.Inherited); err != nil { + return err + } else { + return nil + } + } else if !args && flagTag == "" { + // Neither a positional args target nor a flag - do nothing and exit + return nil + } else if !fieldValue.CanAddr() { + // Field must be addressable or we will not be able to update it with CLI arguments + return fmt.Errorf("not addressable") + } else if !fieldValue.CanSet() { + // Field must be settable or we will not be able to update it with CLI arguments + return fmt.Errorf("not settable") + } else if args { + // If field is tagged with "args", it cannot also serve as a flag; it also must be of type "[]string" + if flagTag != "" { + return &ErrInvalidTag{Cause: fmt.Errorf("cannot be a flag as well"), Tag: TagArgs, Value: strconv.FormatBool(args)} + } else if structField.Type.ConvertibleTo(reflect.TypeOf([]string{})) { + fs.positionalsTargets = append(fs.positionalsTargets, fieldValue.Addr().Interface().(*[]string)) + return nil + } else { + return &ErrInvalidTag{Cause: fmt.Errorf("must be typed as []string"), Tag: TagArgs, Value: strconv.FormatBool(args)} + } + } + + // Configure whether flag should be given a value in the CLI, and the default value if one is not provided + switch fieldValue.Kind() { + case reflect.Bool: + fd.HasValue = false + fd.DefaultValue = "false" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fd.HasValue = true + fd.DefaultValue = strconv.FormatInt(fieldValue.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + fd.HasValue = true + fd.DefaultValue = strconv.FormatUint(fieldValue.Uint(), 10) + case reflect.Float32, reflect.Float64: + fd.HasValue = true + fd.DefaultValue = strconv.FormatFloat(fieldValue.Float(), 'g', -1, 64) + case reflect.String: + fd.HasValue = true + fd.DefaultValue = fieldValue.String() + default: + // Unsupported flag field type + return fmt.Errorf("unsupported field type: %s", fieldValue.Kind()) + } + + // Otherwise, this is a flag - check if it has already been registered? + for _, fdi := range fs.flags { + if fdi.Name == fd.Name { + if fdi.EnvVarName == nil { + fdi.EnvVarName = fd.EnvVarName + } else if fd.EnvVarName != nil && *fdi.EnvVarName != *fd.EnvVarName { + return &ErrInvalidTag{Cause: fmt.Errorf("cannot redefine environment variable name"), Tag: TagEnv, Value: *fd.EnvVarName} + } + if fdi.HasValue != fd.HasValue { + return fmt.Errorf("incompatible field types detected (is one a bool and another isn't?)") + } + if fdi.ValueName == nil { + fdi.ValueName = fd.ValueName + } else if fd.ValueName != nil && *fdi.ValueName != *fd.ValueName { + return &ErrInvalidTag{Cause: fmt.Errorf("cannot redefine value name"), Tag: TagValueName, Value: *fd.ValueName} + } + if fdi.Description == nil { + fdi.Description = fd.Description + } else if fd.Description != nil && *fdi.Description != *fd.Description { + return &ErrInvalidTag{Cause: fmt.Errorf("cannot redefine description"), Tag: TagDescription, Value: *fd.Description} + } + if fdi.Required == nil { + fdi.Required = fd.Required + } else if fd.Required != nil && *fdi.Required != *fd.Required { + return &ErrInvalidTag{Cause: fmt.Errorf("cannot redefine required status"), Tag: TagRequired, Value: strconv.FormatBool(*fd.Required)} + } + if fdi.DefaultValue != fd.DefaultValue { + return fmt.Errorf("incompatible default values detected: '%s' vs '%s'", fdi.DefaultValue, fd.DefaultValue) + } + if fdi.Inherited != fd.Inherited { + return fmt.Errorf("incompatible inherited status detected: '%v' vs '%v'", fdi.Inherited, fd.Inherited) + } + fdi.Targets = append(fdi.Targets, fd.Targets...) + return nil + } + } + + // New flag, add it as is + fs.flags = append(fs.flags, fd) + return nil +} + +func (fs *flagSet) getMergedFlagDefs() ([]*mergedFlagDef, error) { + flags := make(map[string]*mergedFlagDef) + for cfs := fs; cfs != nil; cfs = cfs.parent { + for _, fd := range cfs.flags { + if cfs == fs || fd.Inherited { + if mfd, ok := flags[fd.Name]; !ok { + flags[fd.Name] = &mergedFlagDef{ + flagInfo: flagInfo{ + Name: fd.Name, + EnvVarName: fd.EnvVarName, + HasValue: fd.HasValue, + ValueName: fd.ValueName, + Description: fd.Description, + Required: fd.Required, + DefaultValue: fd.DefaultValue, + }, + applied: false, + flagDefs: []*flagDef{fd}, + } + } else if err := mfd.addFlagDef(fd); err != nil { + return nil, err + } + } + } + } + var mergedFlagDefs []*mergedFlagDef + for _, mfd := range flags { + if mfd.EnvVarName == nil { + mfd.EnvVarName = ptrOf(flagNameToEnvVarName(mfd.Name)) + } + if mfd.ValueName == nil { + mfd.ValueName = ptrOf("VALUE") + } + if mfd.Required == nil { + mfd.Required = ptrOf(false) + } + sort.Slice(mfd.flagDefs, func(ai, bi int) bool { return mfd.flagDefs[ai].isLessThan(mfd.flagDefs[bi]) }) + mergedFlagDefs = append(mergedFlagDefs, mfd) + } + sort.Slice(mergedFlagDefs, func(ai, bi int) bool { return cmp.Less(mergedFlagDefs[ai].Name, mergedFlagDefs[bi].Name) }) + return mergedFlagDefs, nil +} + +func (fs *flagSet) apply(envVars map[string]string, args []string) error { + if args == nil { + args = []string{} + } + if envVars == nil { + envVars = make(map[string]string) + } + + stdFs := flag.NewFlagSet("", flag.ContinueOnError) + stdFs.SetOutput(io.Discard) + + // Merge flags from this flag set and its parents + mergedFlagDefs, err := fs.getMergedFlagDefs() + if err != nil { + return err + } + + // Iterate flags and define them in the stdlib FlagSet + for _, mfd := range mergedFlagDefs { + + // Set the value to the flag's corresponding environment variable, if one was given + if v, found := envVars[*mfd.EnvVarName]; found { + if err := mfd.setValue(v); err != nil { + return err + } + } + + // By definition, for the same name - all flags have the same "HasValue" value, so it should be safe to just + // take it from the first one + if mfd.HasValue { + stdFs.Func(mfd.Name, "", func(v string) error { return mfd.setValue(v) }) + } else { + stdFs.BoolFunc(mfd.Name, "", func(string) error { return mfd.setValue("true") }) + } + } + + // Parse the given arguments, which will result in all CLI flags being set + if err := stdFs.Parse(args); err != nil { + re := regexp.MustCompile(`^flag provided but not defined: -(.+)$`) + if matches := re.FindStringSubmatch(err.Error()); matches != nil { + return &ErrUnknownFlag{Cause: err, Flag: matches[1]} + } + return err + } + + // Verify all required flags have been set + for _, mfd := range mergedFlagDefs { + if mfd.isMissing() { + return &ErrRequiredFlagMissing{Cause: err, Flag: mfd.Name} + } + } + + // Apply positionals + positionals := stdFs.Args() + for cfs := fs; cfs != nil; cfs = cfs.parent { + for _, target := range cfs.positionalsTargets { + *target = positionals + } + } + return nil +} + +func (fs *flagSet) printFlagsSingleLine(b io.Writer) error { + + // Merge flags from this flag set and its parents + mergedFlagDefs, err := fs.getMergedFlagDefs() + if err != nil { + return err + } + + space := false + for _, fd := range mergedFlagDefs { + if space { + _, _ = fmt.Fprint(b, " ") + } else { + space = true + } + if !fd.isRequired() { + _, _ = fmt.Fprint(b, "[") + } + + valueName := fd.getValueName() + if valueName != "" { + _, _ = fmt.Fprintf(b, "--%s=%s", fd.Name, valueName) + } else { + _, _ = fmt.Fprintf(b, "--%s", fd.Name) + } + if !fd.isRequired() { + _, _ = fmt.Fprint(b, "]") + } + } + if len(fs.positionalsTargets) > 0 { + if space { + _, _ = fmt.Fprint(b, " ") + } + _, _ = fmt.Fprint(b, "[ARGS...]") + } + + return nil +} + +func (fs *flagSet) printFlagsMultiLine(ww *WrappingWriter, basePrefix string) error { + + // Merge flags from this flag set and its parents + mergedFlagDefs, err := fs.getMergedFlagDefs() + if err != nil { + return err + } + + flagsColWidth := 0 + fullFlagNames := make(map[string]string) + for _, fd := range mergedFlagDefs { + var fullFlagName string + valueName := fd.getValueName() + if valueName != "" { + fullFlagName = fmt.Sprintf("--%s=%s", fd.Name, valueName) + } else { + fullFlagName = fmt.Sprintf("--%s", fd.Name) + } + if fd.Required == nil || !*fd.Required { + fullFlagName = "[" + fullFlagName + "]" + } + fullFlagNames[fd.Name] = fullFlagName + if len(fullFlagName) > flagsColWidth { + flagsColWidth = len(fullFlagName) + } + } + + descriptionStartColumn := flagsColWidth + (10 - flagsColWidth%10) + for _, fd := range mergedFlagDefs { + flagName := fullFlagNames[fd.Name] + _, _ = fmt.Fprint(ww, flagName) + _, _ = fmt.Fprint(ww, strings.Repeat(" ", descriptionStartColumn-len(flagName))) + _ = ww.SetLinePrefix(basePrefix + strings.Repeat(" ", descriptionStartColumn)) + + // Build flag description + hasDescription := fd.Description != nil && *fd.Description != "" + var sep string + if hasDescription { + _, _ = fmt.Fprint(ww, *fd.Description) + sep = " (" + } + + if fd.DefaultValue != "" { + if sep != "" { + _, _ = fmt.Fprint(ww, sep) + } + _, _ = fmt.Fprintf(ww, "default value: %s", fd.DefaultValue) + sep = ", " + } + if fd.EnvVarName != nil { + if sep != "" { + _, _ = fmt.Fprint(ww, sep) + } + _, _ = fmt.Fprintf(ww, "environment variable: %s", *fd.EnvVarName) + } + if hasDescription { + _, _ = fmt.Fprint(ww, ")") + } + + _ = ww.SetLinePrefix(basePrefix) + _, _ = fmt.Fprintln(ww) + } + + return nil +} diff --git a/flag_set_test.go b/flag_set_test.go new file mode 100644 index 0000000..b69c41f --- /dev/null +++ b/flag_set_test.go @@ -0,0 +1,1007 @@ +package command + +import ( + "bytes" + stdcmp "cmp" + "reflect" + "testing" + + . "github.com/arikkfir/justest" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestNewFlagSet(t *testing.T) { + t.Parallel() + type testCase struct { + config any + expectedError string + expectedFlags func(tc *testCase) []*flagDef + expectedPositionalsTargets func(tc *testCase) []*[]string + } + testCases := map[string]testCase{ + "nil config": {}, + "config wih no flags": {config: &struct{}{}}, + "config with ignored flags": {config: &struct{ MyField string }{MyField: "abc"}}, + "config with a single flag": { + config: &struct { + MyField string `name:"my-field" env:"MY_FIELD" value-name:"VVV" desc:"desc" required:"true" inherited:"true" args:"false"` + }{MyField: "abc"}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + EnvVarName: ptrOf("MY_FIELD"), + HasValue: true, + ValueName: ptrOf("VVV"), + Description: ptrOf("desc"), + Required: ptrOf(true), + DefaultValue: "abc", + }, + Inherited: true, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "non struct pointer config is ignored": { + config: struct { + MyField string `name:"my-field" env:"MY_FIELD" value-name:"VVV" desc:"desc" required:"true" inherited:"true" args:"false"` + }{MyField: "abc"}, + }, + "config with multiple flags": { + config: &struct { + MyField1 string `name:"my-field1" env:"MY_FIELD1" value-name:"V1" desc:"desc1" required:"true" inherited:"true" args:"false"` + MyField2 string `name:"my-field2" env:"MY_FIELD2" value-name:"V2" desc:"desc2" required:"false" inherited:"false" args:"false"` + }{MyField1: "abc1", MyField2: "abc2"}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field1", + EnvVarName: ptrOf("MY_FIELD1"), + HasValue: true, + ValueName: ptrOf("V1"), + Description: ptrOf("desc1"), + Required: ptrOf(true), + DefaultValue: "abc1", + }, + Inherited: true, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField1")}, + }, + { + flagInfo: flagInfo{ + Name: "my-field2", + EnvVarName: ptrOf("MY_FIELD2"), + HasValue: true, + ValueName: ptrOf("V2"), + Description: ptrOf("desc2"), + Required: ptrOf(false), + DefaultValue: "abc2", + }, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField2")}, + }, + } + }, + }, + "bad 'flag' tag": { + config: &struct { + MyField string `flag:"bad-value"` + }{}, + expectedError: `^invalid field 'struct \{ MyField string "flag:\\"bad-value\\"" \}.MyField': invalid tag 'flag=bad-value': invalid syntax$`, + }, + "field with 'flag=false' tag is ignored": { + config: &struct { + MyField string `flag:"false"` + }{}, + }, + "field with just 'flag=true' tag is picked up": { + config: &struct { + MyField string `flag:"true"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + HasValue: true, + }, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "field with empty 'name' tag is rejected": { + config: &struct { + MyField string `name:""` + }{}, + expectedError: `^invalid field 'struct \{ MyField string "name:\\"\\"" \}.MyField': invalid tag 'name=': must not be empty$`, + }, + "value of 'name' tag is used": { + config: &struct { + MyField string `name:"a"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "a", HasValue: true}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "field with empty 'env' tag is rejected": { + config: &struct { + MyField string `env:""` + }{}, + expectedError: `^invalid field 'struct \{ MyField string "env:\\"\\"" \}.MyField': invalid tag 'env=': must not be empty$`, + }, + "value of 'env' tag is used and uppercased": { + config: &struct { + MyField string `env:"a"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", EnvVarName: ptrOf("A"), HasValue: true}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "field with empty 'value-name' tag is rejected": { + config: &struct { + MyField string `value-name:""` + }{}, + expectedError: `^invalid field 'struct \{ MyField string "value-name:\\"\\"" \}.MyField': invalid tag 'value-name=': must not be empty$`, + }, + "value of 'value-name' tag is used": { + config: &struct { + MyField string `value-name:"V"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true, ValueName: ptrOf("V")}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "field with empty 'description' tag is allowed": { + config: &struct { + MyField string `desc:""` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true, Description: ptrOf("")}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "value of 'description' tag is used": { + config: &struct { + MyField string `desc:"Some Description"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true, Description: ptrOf("Some Description")}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "bad 'required' tag": { + config: &struct { + MyField string `required:"bad-value"` + }{}, + expectedError: `^invalid field 'struct \{ MyField string "required:\\"bad-value\\"" \}.MyField': invalid tag 'required=bad-value': invalid syntax$`, + }, + "field with 'required=false' tag is not required": { + config: &struct { + MyField string `required:"false"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true, Required: ptrOf(false)}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "field with 'required=true' tag is required": { + config: &struct { + MyField string `required:"true"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true, Required: ptrOf(true)}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "bad 'inherited' tag": { + config: &struct { + MyField string `inherited:"bad-value"` + }{}, + expectedError: `^invalid field 'struct \{ MyField string "inherited:\\"bad-value\\"" \}.MyField': invalid tag 'inherited=bad-value': invalid syntax$`, + }, + "field with 'inherited=false' tag is not inherited": { + config: &struct { + MyField string `inherited:"false"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true}, + Inherited: false, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "field with 'inherited=true' tag is inherited": { + config: &struct { + MyField string `inherited:"true"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true}, + Inherited: true, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "bad 'args' tag": { + config: &struct { + MyField string `args:"bad-value"` + }{}, + expectedError: `^invalid field 'struct \{ MyField string "args:\\"bad-value\\"" \}.MyField': invalid tag 'args=bad-value': invalid syntax$`, + }, + "field with 'args=false' tag is not marked as args": { + config: &struct { + MyField string `name:"f" args:"false"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "f", HasValue: true}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "field with 'args=true' tag is marked as args": { + config: &struct { + MyField []string `args:"true"` + }{}, + expectedPositionalsTargets: func(tc *testCase) []*[]string { + typedVal := reflect.ValueOf(tc.config).Elem().FieldByName("MyField").Interface().([]string) + return []*[]string{&typedVal} + }, + }, + "field with 'name' and 'args' tags is rejected": { + config: &struct { + MyField []string `name:"f" args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField \[\]string "name:\\"f\\" args:\\"true\\"" \}.MyField': invalid tag 'args=true': cannot be a flag as well$`, + }, + "field with 'env' and 'args' tags is rejected": { + config: &struct { + MyField []string `env:"f" args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField \[\]string "env:\\"f\\" args:\\"true\\"" \}.MyField': invalid tag 'args=true': cannot be a flag as well$`, + }, + "field with 'value-name' and 'args' tags is rejected": { + config: &struct { + MyField []string `value-name:"f" args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField \[\]string "value-name:\\"f\\" args:\\"true\\"" \}.MyField': invalid tag 'args=true': cannot be a flag as well$`, + }, + "field with 'desc' and 'args' tags is rejected": { + config: &struct { + MyField []string `desc:"f" args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField \[\]string "desc:\\"f\\" args:\\"true\\"" \}.MyField': invalid tag 'args=true': cannot be a flag as well$`, + }, + "field with 'required' and 'args' tags is rejected": { + config: &struct { + MyField []string `required:"true" args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField \[\]string "required:\\"true\\" args:\\"true\\"" \}.MyField': invalid tag 'args=true': cannot be a flag as well$`, + }, + "field with 'inherited' and 'args' tags is rejected": { + config: &struct { + MyField []string `inherited:"true" args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField \[\]string "inherited:\\"true\\" args:\\"true\\"" \}.MyField': invalid tag 'args=true': cannot be a flag as well$`, + }, + "field with 'args' of incorrect type is rejected": { + config: &struct { + MyField int `args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField int "args:\\"true\\"" \}.MyField': invalid tag 'args=true': must be typed as \[\]string$`, + }, + "struct field cannot use 'args' tag": { + config: &struct { + MyField struct{} `args:"true"` + }{}, + expectedError: `^invalid field 'struct \{ MyField struct \{\} "args:\\"true\\"" \}.MyField': invalid tag 'args=true': cannot be used on struct fields$`, + }, + "flag name is inferred from field name": { + config: &struct { + MyField int `flag:"true"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{Name: "my-field", HasValue: true, DefaultValue: "0"}, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyField")}, + }, + } + }, + }, + "tag 'value-name' is not allowed for bool fields": { + config: &struct { + MyField bool `value-name:"VAL"` + }{}, + expectedError: `^invalid field 'struct \{ MyField bool "value-name:\\"VAL\\"" \}.MyField': invalid tag 'value-name=VAL': not supported for bool fields$`, + }, + "nested config": { + config: &struct { + OuterField1 string `name:"outer-field1" env:"OUTER_FIELD1" value-name:"outer-V1" desc:"outer-desc1" required:"true" inherited:"true"` + OuterField2 string `name:"outer-field2" env:"OUTER_FIELD2" value-name:"outer-V2" desc:"outer-desc2" required:"false" inherited:"false"` + OuterArgs []string `args:"true"` + MyStruct struct { + InnerField1 string `name:"inner-field1" env:"INNER_FIELD1" value-name:"inner-V1" desc:"inner-desc1" required:"true" inherited:"true"` + InnerField2 string `name:"inner-field2" env:"INNER_FIELD2" value-name:"inner-V2" desc:"inner-desc2" required:"false" inherited:"false"` + InnerArgs []string `args:"true"` + } + }{ + OuterField1: "out1", + OuterField2: "out2", + MyStruct: struct { + InnerField1 string `name:"inner-field1" env:"INNER_FIELD1" value-name:"inner-V1" desc:"inner-desc1" required:"true" inherited:"true"` + InnerField2 string `name:"inner-field2" env:"INNER_FIELD2" value-name:"inner-V2" desc:"inner-desc2" required:"false" inherited:"false"` + InnerArgs []string `args:"true"` + }{ + InnerField1: "inner1", + InnerField2: "inner2", + }, + }, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "outer-field1", + EnvVarName: ptrOf("OUTER_FIELD1"), + HasValue: true, + ValueName: ptrOf("outer-V1"), + Description: ptrOf("outer-desc1"), + Required: ptrOf(true), + DefaultValue: "out1", + }, + Inherited: true, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("OuterField1")}, + }, + { + flagInfo: flagInfo{ + Name: "outer-field2", + EnvVarName: ptrOf("OUTER_FIELD2"), + HasValue: true, + ValueName: ptrOf("outer-V2"), + Description: ptrOf("outer-desc2"), + Required: ptrOf(false), + DefaultValue: "out2", + }, + Inherited: false, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("OuterField2")}, + }, + { + flagInfo: flagInfo{ + Name: "inner-field1", + EnvVarName: ptrOf("INNER_FIELD1"), + HasValue: true, + ValueName: ptrOf("inner-V1"), + Description: ptrOf("inner-desc1"), + Required: ptrOf(true), + DefaultValue: "inner1", + }, + Inherited: true, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyStruct").FieldByName("InnerField1")}, + }, + { + flagInfo: flagInfo{ + Name: "inner-field2", + EnvVarName: ptrOf("INNER_FIELD2"), + HasValue: true, + ValueName: ptrOf("inner-V2"), + Description: ptrOf("inner-desc2"), + Required: ptrOf(false), + DefaultValue: "inner2", + }, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("MyStruct").FieldByName("InnerField2")}, + }, + } + }, + expectedPositionalsTargets: func(tc *testCase) []*[]string { + valueOfOuterArgs := reflect.ValueOf(tc.config).Elem().FieldByName("OuterArgs").Interface().([]string) + valueOfInnerArgs := reflect.ValueOf(tc.config).Elem().FieldByName("MyStruct").FieldByName("InnerArgs").Interface().([]string) + return []*[]string{&valueOfOuterArgs, &valueOfInnerArgs} + }, + }, + "redeclared field cannot change environment variable": { + config: &struct { + MyField1 string `name:"my-field1" env:"MY_FIELD1"` + MyField2 string `name:"my-field1" env:"MY_FIELD2"` + }{}, + expectedError: `^invalid field 'struct \{ MyField1 string "name:\\"my-field1\\" env:\\"MY_FIELD1\\""; MyField2 string "name:\\"my-field1\\" env:\\"MY_FIELD2\\"" }.MyField2': invalid tag 'env=MY_FIELD2': cannot redefine environment variable name$`, + }, + "redeclared field can set environment variable": { + config: &struct { + MyField1 string `name:"my-field"` + MyField2 string `name:"my-field" env:"MF"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + EnvVarName: ptrOf("MF"), + HasValue: true, + }, + Targets: []reflect.Value{ + reflect.ValueOf(tc.config).Elem().FieldByName("MyField1"), + reflect.ValueOf(tc.config).Elem().FieldByName("MyField2"), + }, + }, + } + }, + }, + "redeclared field cannot change has-value": { + config: &struct { + MyField1 string `name:"my-field1"` + MyField2 bool `name:"my-field1"` + }{}, + expectedError: `^invalid field 'struct \{ MyField1 string "name:\\"my-field1\\""; MyField2 bool "name:\\"my-field1\\"" }.MyField2': incompatible field types detected \(is one a bool and another isn't\?\)$`, + }, + "redeclared field cannot change value-name": { + config: &struct { + MyField1 string `name:"my-field1" value-name:"V1"` + MyField2 string `name:"my-field1" value-name:"V2"` + }{}, + expectedError: `^invalid field 'struct \{ MyField1 string "name:\\"my-field1\\" value-name:\\"V1\\""; MyField2 string "name:\\"my-field1\\" value-name:\\"V2\\"" }.MyField2': invalid tag 'value-name=V2': cannot redefine value name$`, + }, + "redeclared field can set value-name": { + config: &struct { + MyField1 string `name:"my-field"` + MyField2 string `name:"my-field" value-name:"VVV"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + HasValue: true, + ValueName: ptrOf("VVV"), + }, + Targets: []reflect.Value{ + reflect.ValueOf(tc.config).Elem().FieldByName("MyField1"), + reflect.ValueOf(tc.config).Elem().FieldByName("MyField2"), + }, + }, + } + }, + }, + "redeclared field cannot change description": { + config: &struct { + MyField1 string `name:"my-field1" desc:"V1"` + MyField2 string `name:"my-field1" desc:"V2"` + }{}, + expectedError: `^invalid field 'struct \{ MyField1 string "name:\\"my-field1\\" desc:\\"V1\\""; MyField2 string "name:\\"my-field1\\" desc:\\"V2\\"" }.MyField2': invalid tag 'desc=V2': cannot redefine description$`, + }, + "redeclared field can set description": { + config: &struct { + MyField1 string `name:"my-field"` + MyField2 string `name:"my-field" desc:"DESC"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + HasValue: true, + Description: ptrOf("DESC"), + }, + Targets: []reflect.Value{ + reflect.ValueOf(tc.config).Elem().FieldByName("MyField1"), + reflect.ValueOf(tc.config).Elem().FieldByName("MyField2"), + }, + }, + } + }, + }, + "redeclared field cannot change required status": { + config: &struct { + MyField1 string `name:"my-field1" required:"true"` + MyField2 string `name:"my-field1" required:"false"` + }{}, + expectedError: `^invalid field 'struct \{ MyField1 string "name:\\"my-field1\\" required:\\"true\\""; MyField2 string "name:\\"my-field1\\" required:\\"false\\"" }.MyField2': invalid tag 'required=false': cannot redefine required status$`, + }, + "redeclared field can set required status": { + config: &struct { + MyField1 string `name:"my-field"` + MyField2 string `name:"my-field" required:"true"` + }{}, + expectedFlags: func(tc *testCase) []*flagDef { + return []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + HasValue: true, + Required: ptrOf(true), + }, + Targets: []reflect.Value{ + reflect.ValueOf(tc.config).Elem().FieldByName("MyField1"), + reflect.ValueOf(tc.config).Elem().FieldByName("MyField2"), + }, + }, + } + }, + }, + "redeclared field cannot change default value": { + config: &struct { + MyField1 string `name:"my-field1"` + MyField2 string `name:"my-field1"` + }{ + MyField1: "v1", + MyField2: "v2", + }, + expectedError: `^invalid field 'struct \{ MyField1 string "name:\\"my-field1\\""; MyField2 string "name:\\"my-field1\\"" }.MyField2': incompatible default values detected: 'v1' vs 'v2'$`, + }, + "redeclared field cannot change inherited status": { + config: &struct { + MyField1 string `name:"my-field1" inherited:"true"` + MyField2 string `name:"my-field1" inherited:"false"` + }{}, + expectedError: `^invalid field 'struct \{ MyField1 string "name:\\"my-field1\\" inherited:\\"true\\""; MyField2 string "name:\\"my-field1\\" inherited:\\"false\\"" }.MyField2': incompatible inherited status detected: 'true' vs 'false'$`, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + valueOfConfig := reflect.ValueOf(tc.config) + if tc.expectedError != "" { + With(t).Verify(newFlagSet(nil, valueOfConfig)).Will(Fail(tc.expectedError)).OrFail() + } else { + fs, err := newFlagSet(nil, valueOfConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + if tc.expectedFlags != nil { + expectedFlags := tc.expectedFlags(&tc) + With(t). + Verify(fs.flags). + Will(EqualTo( + expectedFlags, + cmp.AllowUnexported(flagDef{}), + cmpopts.SortSlices(func(a *flagDef, b *flagDef) bool { return stdcmp.Less(a.Name, b.Name) }), + )). + OrFail() + } else { + With(t).Verify(fs.flags).Will(BeNil()).OrFail() + } + if tc.expectedPositionalsTargets != nil { + With(t).Verify(fs.positionalsTargets).Will(EqualTo(tc.expectedPositionalsTargets(&tc))).OrFail() + } else { + With(t).Verify(fs.positionalsTargets).Will(BeNil()).OrFail() + } + } + }) + } +} + +func TestFlagSetGetMergedFlagDefs(t *testing.T) { + t.Parallel() + type testCase struct { + parentConfig any + config any + expectedError string + expectedFlags func(tc *testCase) []*mergedFlagDef + } + testCases := map[string]testCase{ + "no parent": { + config: &struct { + F string `name:"my-field" env:"MY_FIELD" desc:"desc" inherited:"true"` + S struct { + F string `name:"my-field" value-name:"VVV" required:"true" inherited:"true"` + } + }{ + F: "abc", + S: struct { + F string `name:"my-field" value-name:"VVV" required:"true" inherited:"true"` + }{F: "abc"}, + }, + expectedFlags: func(tc *testCase) []*mergedFlagDef { + return []*mergedFlagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + EnvVarName: ptrOf("MY_FIELD"), + HasValue: true, + ValueName: ptrOf("VVV"), + Description: ptrOf("desc"), + Required: ptrOf(true), + DefaultValue: "abc", + }, + flagDefs: []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field", + EnvVarName: ptrOf("MY_FIELD"), + HasValue: true, + ValueName: ptrOf("VVV"), + Required: ptrOf(true), + Description: ptrOf("desc"), + DefaultValue: "abc", + }, + Inherited: true, + Targets: []reflect.Value{ + reflect.ValueOf(tc.config).Elem().FieldByName("F"), + reflect.ValueOf(tc.config).Elem().FieldByName("S").FieldByName("F"), + }, + }, + }, + }, + } + }, + }, + "flags merged across parents": { + parentConfig: &struct { + F1 string `name:"my-field1" env:"MF1" value-name:"VVV" inherited:"true"` + }{F1: "v1"}, + config: &struct { + F1 string `name:"my-field1"` + F11 string `name:"my-field1" desc:"desc1"` + F2 string `name:"my-field2" env:"MF2" desc:"desc2"` + }{ + F1: "v1", + F11: "v1", + F2: "v2", + }, + expectedFlags: func(tc *testCase) []*mergedFlagDef { + return []*mergedFlagDef{ + { + flagInfo: flagInfo{ + Name: "my-field1", + EnvVarName: ptrOf("MF1"), + HasValue: true, + ValueName: ptrOf("VVV"), + Description: ptrOf("desc1"), + Required: ptrOf(false), + DefaultValue: "v1", + }, + flagDefs: []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field1", + HasValue: true, + Description: ptrOf("desc1"), + DefaultValue: "v1", + }, + Inherited: false, + Targets: []reflect.Value{ + reflect.ValueOf(tc.config).Elem().FieldByName("F1"), + reflect.ValueOf(tc.config).Elem().FieldByName("F11"), + }, + }, + { + flagInfo: flagInfo{ + Name: "my-field1", + EnvVarName: ptrOf("MF1"), + HasValue: true, + ValueName: ptrOf("VVV"), + DefaultValue: "v1", + }, + Inherited: true, + Targets: []reflect.Value{reflect.ValueOf(tc.parentConfig).Elem().FieldByName("F1")}, + }, + }, + }, + { + flagInfo: flagInfo{ + Name: "my-field2", + EnvVarName: ptrOf("MF2"), + HasValue: true, + ValueName: ptrOf("VALUE"), + Description: ptrOf("desc2"), + Required: ptrOf(false), + DefaultValue: "v2", + }, + flagDefs: []*flagDef{ + { + flagInfo: flagInfo{ + Name: "my-field2", + EnvVarName: ptrOf("MF2"), + HasValue: true, + Description: ptrOf("desc2"), + DefaultValue: "v2", + }, + Targets: []reflect.Value{reflect.ValueOf(tc.config).Elem().FieldByName("F2")}, + }, + }, + }, + } + }, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + var parent *flagSet + if tc.parentConfig != nil { + valueOfParentConfig := reflect.ValueOf(tc.parentConfig) + fs, err := newFlagSet(nil, valueOfParentConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + parent = fs + } + valueOfConfig := reflect.ValueOf(tc.config) + if tc.expectedError != "" { + With(t).Verify(newFlagSet(parent, valueOfConfig)).Will(Fail(tc.expectedError)).OrFail() + } else { + fs, err := newFlagSet(parent, valueOfConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + if tc.expectedFlags != nil { + mergedFlagDefs, err := fs.getMergedFlagDefs() + With(t).Verify(err).Will(BeNil()).OrFail() + With(t). + Verify(mergedFlagDefs). + Will(EqualTo(tc.expectedFlags(&tc), cmp.AllowUnexported(flagDef{}, mergedFlagDef{}))).OrFail() + } else { + With(t).Verify(fs.flags).Will(BeNil()).OrFail() + } + } + }) + } +} + +func TestFlagSetUsagePrinting(t *testing.T) { + t.Parallel() + type testCase struct { + parentConfig any + config any + width int + expectedSingleLineUsage string + expectedMultiLineUsage string + } + testCases := map[string]testCase{ + "no parent, single flag for multiple fields in nested structure": { + config: &struct { + F string `name:"my-field" env:"MY_FIELD" desc:"desc" inherited:"true"` + S struct { + F string `name:"my-field" value-name:"VVV" required:"true" inherited:"true"` + } + }{ + F: "abc", + S: struct { + F string `name:"my-field" value-name:"VVV" required:"true" inherited:"true"` + }{F: "abc"}, + }, + expectedSingleLineUsage: `--my-field=VVV`, + expectedMultiLineUsage: ` +--my-field=VVV desc (default value: abc, environment variable: + MY_FIELD) +`, + }, + "flags merged across parents": { + parentConfig: &struct { + F1 string `name:"my-field1" env:"MF1" value-name:"VVV" inherited:"true"` + }{F1: "v1"}, + config: &struct { + F1 string `name:"my-field1" required:"true"` + F11 string `name:"my-field1" desc:"desc1"` + F2 bool `name:"my-field2" env:"MF2" desc:"desc2"` + }{ + F1: "v1", + F11: "v1", + }, + expectedSingleLineUsage: `--my-field1=VVV [--my-field2]`, + expectedMultiLineUsage: ` +--my-field1=VVV desc1 (default value: v1, environment variable: + MF1) +[--my-field2] desc2 (default value: false, environment + variable: MF2) +`, + }, + "positionals without flags": { + config: &struct { + Args []string `args:"true"` + }{}, + expectedSingleLineUsage: `[ARGS...]`, + expectedMultiLineUsage: ` +`, + }, + "flags and positionals": { + config: &struct { + F1 string `name:"my-field1" value-name:"FF"` + F2 bool `name:"my-field2" env:"MF2" desc:"desc2"` + Args []string `args:"true"` + }{ + F1: "v1", + }, + expectedSingleLineUsage: `[--my-field1=FF] [--my-field2] [ARGS...]`, + expectedMultiLineUsage: ` +[--my-field1=FF] default value: v1, environment variable: MY_FIELD1 +[--my-field2] desc2 (default value: false, environment + variable: MF2) +`, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + var parent *flagSet + if tc.parentConfig != nil { + valueOfParentConfig := reflect.ValueOf(tc.parentConfig) + fs, err := newFlagSet(nil, valueOfParentConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + parent = fs + } + valueOfConfig := reflect.ValueOf(tc.config) + + fs, err := newFlagSet(parent, valueOfConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + + width := tc.width + if width == 0 { + width = 70 + } + + singleLine := &bytes.Buffer{} + With(t).Verify(fs.printFlagsSingleLine(singleLine)).Will(Succeed()).OrFail() + With(t).Verify(singleLine.String()).Will(EqualTo(tc.expectedSingleLineUsage)).OrFail() + + multiLine, err := NewWrappingWriter(width) + With(t).Verify(err).Will(BeNil()).OrFail() + With(t).Verify(fs.printFlagsMultiLine(multiLine, "")).Will(Succeed()).OrFail() + With(t).Verify(multiLine.String()).Will(EqualTo(tc.expectedMultiLineUsage[1:])).OrFail() + }) + } +} + +func TestFlagSetApply(t *testing.T) { + t.Parallel() + type testCase struct { + parentConfig any + config any + envVars map[string]string + args []string + expectedParentConfig any + expectedConfig any + expectedError string + } + testCases := map[string]testCase{ + "CLI overrides environment variables": { + config: &struct { + F1 string `name:"my-field1"` + }{}, + envVars: map[string]string{ + "MY_FIELD1": "should not be used", + }, + args: []string{"--my-field1=CLI value for F1"}, + expectedConfig: &struct { + F1 string `name:"my-field1"` + }{F1: "CLI value for F1"}, + }, + "correct environment variable used for flag": { + config: &struct { + F1 string `name:"my-field1" env:"MF1"` + }{}, + envVars: map[string]string{ + "MY_FIELD1": "should not be used", + "MF1": "correct value for F1", + }, + args: []string{}, + expectedConfig: &struct { + F1 string `name:"my-field1" env:"MF1"` + }{F1: "correct value for F1"}, + }, + "default value preserved": { + config: &struct { + F1 string `name:"my-field1" env:"MF1"` + F2 string `name:"my-field2"` + F3 string `name:"my-field3"` + F4 string `name:"my-field4"` + }{F1: "default1", F2: "default2", F3: "default3", F4: "default4"}, + envVars: map[string]string{ + "MY_FIELD1": "should not be used", + "MF1": "correct value for F1", + "MY_FIELD2": "correct value for F2", + }, + args: []string{"--my-field3=correct value for F3"}, + expectedConfig: &struct { + F1 string `name:"my-field1" env:"MF1"` + F2 string `name:"my-field2"` + F3 string `name:"my-field3"` + F4 string `name:"my-field4"` + }{ + F1: "correct value for F1", + F2: "correct value for F2", + F3: "correct value for F3", + F4: "default4", + }, + }, + "both flags and positionals applied": { + config: &struct { + F1 string `name:"my-field1" env:"MF1"` + Args []string `args:"true"` + }{}, + envVars: map[string]string{}, + args: []string{ + "--my-field1=correct value for F1", + "a", + "b", + "c", + }, + expectedConfig: &struct { + F1 string `name:"my-field1" env:"MF1"` + Args []string `args:"true"` + }{ + F1: "correct value for F1", + Args: []string{"a", "b", "c"}, + }, + }, + "invalid flag error": { + config: &struct { + F1 string `name:"my-field1"` + }{}, + envVars: map[string]string{}, + args: []string{"--my-field1=VVV1", "--my-field2=VVV2"}, + expectedError: `^unknown flag: --my-field2$`, + }, + "required field is missing error": { + config: &struct { + F1 string `name:"my-field1"` + F2 string `name:"my-field2" required:"true"` + }{F1: "v1"}, + envVars: map[string]string{}, + args: []string{"--my-field1=VVV1"}, + expectedError: `^required flag is missing: --my-field2$`, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + var parent *flagSet + if tc.parentConfig != nil { + valueOfParentConfig := reflect.ValueOf(tc.parentConfig) + fs, err := newFlagSet(nil, valueOfParentConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + parent = fs + } + valueOfConfig := reflect.ValueOf(tc.config) + + fs, err := newFlagSet(parent, valueOfConfig) + With(t).Verify(err).Will(BeNil()).OrFail() + + if tc.expectedError != "" { + With(t).Verify(fs.apply(tc.envVars, tc.args)).Will(Fail(tc.expectedError)).OrFail() + } else { + With(t).Verify(fs.apply(tc.envVars, tc.args)).Will(Succeed()).OrFail() + With(t).Verify(tc.parentConfig).Will(EqualTo(tc.expectedParentConfig)).OrFail() + With(t).Verify(tc.config).Will(EqualTo(tc.expectedConfig)).OrFail() + } + }) + } +} diff --git a/go.mod b/go.mod index ecc93d9..1e25c9c 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,16 @@ module github.com/arikkfir/command -go 1.21.0 +go 1.22.0 require ( github.com/arikkfir/justest v0.3.1 + github.com/go-loremipsum/loremipsum v1.1.3 github.com/google/go-cmp v0.6.0 + golang.org/x/sys v0.20.0 ) require ( github.com/alecthomas/chroma/v2 v2.13.0 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/stretchr/testify v1.8.4 // indirect ) diff --git a/go.sum b/go.sum index 7d38b50..bf6fe38 100644 --- a/go.sum +++ b/go.sum @@ -6,9 +6,21 @@ github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/arikkfir/justest v0.3.1 h1:zvdeXjyP6YmoTQRtF9On6Ykjp/XDBlQQQnwAhPKPg0o= github.com/arikkfir/justest v0.3.1/go.mod h1:zgaQvIBfgdiP2JZhsnaKzhxux84BxI8aI84Kc/OSDgE= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/go-loremipsum/loremipsum v1.1.3 h1:ZRhA0ZmJ49lGe5HhWeMONr+iGftWDsHfrYBl5ktDXso= +github.com/go-loremipsum/loremipsum v1.1.3/go.mod h1:OJQjXdvwlG9hsyhmMQoT4HOm4DG4l62CYywebw0XBoo= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/util.go b/util.go index c2b9871..ead1e0a 100644 --- a/util.go +++ b/util.go @@ -2,10 +2,31 @@ package command import ( "fmt" + "os" "strings" "unicode" + + "golang.org/x/sys/unix" ) +func ptrOf[T any](v T) *T { + return &v +} + +func defaultIfNil[T any](v *T, defaultValue T) T { + if v == nil { + return defaultValue + } + return *v +} + +func intForBool(b bool) int { + if b { + return 1 + } + return 0 +} + func fieldNameToFlagName(fieldName string) string { var result []rune for i, r := range fieldName { @@ -27,6 +48,10 @@ func fieldNameToFlagName(fieldName string) string { return string(result) } +func flagNameToEnvVarName(flagName string) string { + return strings.ReplaceAll(strings.ToUpper(flagName), "-", "_") +} + func fieldNameToEnvVarName(fieldName string) string { var result []rune for i, r := range fieldName { @@ -48,43 +73,7 @@ func fieldNameToEnvVarName(fieldName string) string { return string(result) } -func environmentVariableToFlagName(name string) string { - return strings.ReplaceAll(strings.ToLower(name), "_", "-") -} - -func inferCommandFlagsAndPositionals(root *Command, args []string) (*Command, []string, []string) { - var flagArgs []string - var positionalArgs []string - - cmd := root - onlyPositionalArgs := false - for i := 0; i < len(args); i++ { - arg := args[i] - - if onlyPositionalArgs { - positionalArgs = append(positionalArgs, arg) - } else if arg == "--" { - onlyPositionalArgs = true - } else if strings.HasPrefix(arg, "-") { - flagArgs = append(flagArgs, arg) - } else { - found := false - for _, subCmd := range cmd.subCommands { - if subCmd.Name == arg { - cmd = subCmd - found = true - break - } - } - if !found { - positionalArgs = append(positionalArgs, arg) - } - } - } - - return cmd, flagArgs, positionalArgs -} - +//goland:noinspection GoUnusedExportedFunction func EnvVarsArrayToMap(envVars []string) map[string]string { envVarsMap := make(map[string]string) for _, nameValue := range envVars { @@ -96,3 +85,12 @@ func EnvVarsArrayToMap(envVars []string) map[string]string { } return envVarsMap } + +func getTerminalWidth() int { + fd := int(os.Stdout.Fd()) + ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ) + if err != nil { + return 80 + } + return int(ws.Col) +} diff --git a/util_test.go b/util_test.go index 81bf827..2fb7a08 100644 --- a/util_test.go +++ b/util_test.go @@ -1,11 +1,9 @@ package command import ( - "strings" "testing" . "github.com/arikkfir/justest" - "github.com/google/go-cmp/cmp/cmpopts" ) func TestFieldNameToFlagName(t *testing.T) { @@ -49,84 +47,3 @@ func TestFieldNameToEnvVarName(t *testing.T) { }) } } - -func Test_inferCommandFlagsAndPositionals(t *testing.T) { - type testCase struct { - root *Command - args []string - expectedCommand *Command - expectedFlags []string - expectedPositionals []string - } - - rootCmd := New(nil, Spec{ - Name: "root", - ShortDescription: "Root command", - LongDescription: "This command is the\nroot command.", - Config: &RootConfig{}, - }) - sub1Cmd := New(rootCmd, Spec{ - Name: "sub1", - ShortDescription: "Sub command 1", - LongDescription: "This command is the\nfirst sub command.", - Config: &Sub1Config{}, - }) - sub2Cmd := New(sub1Cmd, Spec{ - Name: "sub2", - ShortDescription: "Sub command 2", - LongDescription: "This command is the\nsecond sub command.", - Config: &Sub2Config{}, - }) - New(sub2Cmd, Spec{ - Name: "sub3", - ShortDescription: "Sub command 3", - LongDescription: "This command is the\nthird sub command.", - Config: &Sub3Config{}, - }) - - testCases := map[string]testCase{ - "No arguments": { - root: rootCmd, - expectedCommand: rootCmd, - expectedFlags: nil, - expectedPositionals: nil, - }, - "Flags for root command": { - root: rootCmd, - args: strings.Split("-f1 -f2", " "), - expectedCommand: rootCmd, - expectedFlags: []string{"-f1", "-f2"}, - expectedPositionals: nil, - }, - "Flags and positionals for root command": { - root: rootCmd, - args: strings.Split("-f1 a -f2 b", " "), - expectedCommand: rootCmd, - expectedFlags: []string{"-f1", "-f2"}, - expectedPositionals: []string{"a", "b"}, - }, - "Flags and positionals for sub1 command": { - root: rootCmd, - args: strings.Split("-f1 sub1 -f2 a b", " "), - expectedCommand: sub1Cmd, - expectedFlags: []string{"-f1", "-f2"}, - expectedPositionals: []string{"a", "b"}, - }, - "Flags and positionals for sub2 command": { - root: rootCmd, - args: strings.Split("-f1 sub1 -f2 a b sub2 c", " "), - expectedCommand: sub2Cmd, - expectedFlags: []string{"-f1", "-f2"}, - expectedPositionals: []string{"a", "b", "c"}, - }, - } - for name, tc := range testCases { - tc := tc - t.Run(name, func(t *testing.T) { - cmd, flags, pos := inferCommandFlagsAndPositionals(tc.root, tc.args) - With(t).Verify(cmd).Will(EqualTo(tc.expectedCommand, cmpopts.IgnoreUnexported(Command{}))).OrFail() - With(t).Verify(flags).Will(EqualTo(tc.expectedFlags)).OrFail() - With(t).Verify(pos).Will(EqualTo(tc.expectedPositionals)).OrFail() - }) - } -} diff --git a/wrapping_writer.go b/wrapping_writer.go new file mode 100644 index 0000000..dc69752 --- /dev/null +++ b/wrapping_writer.go @@ -0,0 +1,89 @@ +package command + +import ( + "fmt" + "strings" + "unicode" +) + +type WrappingWriter struct { + data []rune + width int + remainingToNextNewLine int + linePrefix string +} + +func NewWrappingWriter(width int) (*WrappingWriter, error) { + if width <= 0 { + return nil, fmt.Errorf("illegal width: %d", width) + } + return &WrappingWriter{data: nil, width: width, remainingToNextNewLine: width}, nil +} + +func (w *WrappingWriter) SetLinePrefix(prefix string) error { + if len(prefix) >= w.width { + return fmt.Errorf("invalid prefix '%s': too larger for width %d", prefix, w.width) + } else if strings.Contains(prefix, "\n") { + return fmt.Errorf("invalid prefix '%s': cannot contain new lines", prefix) + } + w.linePrefix = prefix + return nil +} + +func (w *WrappingWriter) Write(p []byte) (n int, err error) { + srcRunes := []rune(string(p)) + for i := 0; i < len(srcRunes); i++ { + r := srcRunes[i] + if r == '\n' { + if len(w.data) == 0 || (i > 0 && w.data[len(w.data)-1] == '\n') { + w.data = append(w.data, []rune(w.linePrefix)...) + } + w.data = append(w.data, r) + w.remainingToNextNewLine = w.width + } else if w.remainingToNextNewLine == 0 { + for j := len(w.data) - 1; j >= 0; j-- { + rr := w.data[j] + if rr == '\n' { + // Current line has no space; just keep writing this line without splitting it + w.data = append(w.data, r) + break + } else if len(w.data)-j+len(w.linePrefix) >= w.width { + // current line is already at width-length (including prefix) - just keep writing + w.data = append(w.data, r) + break + } else if unicode.IsSpace(rr) { + var runesBeforeSpace, runesAfterSpace []rune + runesBeforeSpace = w.data[0 : j+1] + if j < len(w.data)-1 { + runesAfterSpace = w.data[j+1:] + } + w.data = make([]rune, 0, len(runesBeforeSpace)+len(runesAfterSpace)+1) + w.data = append(w.data, runesBeforeSpace...) + w.data = append(w.data, '\n') + w.data = append(w.data, []rune(w.linePrefix)...) + w.data = append(w.data, runesAfterSpace...) + w.data = append(w.data, r) + + // Remaining characters now equal width minus text after last space, minus the char we just wrote + w.remainingToNextNewLine = w.width - len(w.linePrefix) - len(runesAfterSpace) - 1 + if w.remainingToNextNewLine < 0 { + w.remainingToNextNewLine = 0 + } + break + } + } + } else { + if len(w.data) == 0 || w.data[len(w.data)-1] == '\n' { + w.data = append(w.data, []rune(w.linePrefix)...) + w.remainingToNextNewLine -= len(w.linePrefix) + } + w.data = append(w.data, r) + w.remainingToNextNewLine-- + } + } + return len(p), nil +} + +func (w *WrappingWriter) String() string { + return string(w.data) +} diff --git a/wrapping_writer_test.go b/wrapping_writer_test.go new file mode 100644 index 0000000..4d5d40b --- /dev/null +++ b/wrapping_writer_test.go @@ -0,0 +1,281 @@ +package command + +import ( + "testing" + + . "github.com/arikkfir/justest" +) + +func TestWrappingWriter(t *testing.T) { + t.Parallel() + type testCase struct { + inputs [][]byte + width int + prefix string + expectedString string + } + testCases := map[string]testCase{ + "single input, simple single line under width": { + inputs: [][]byte{ + []byte("hello world"), + }, + width: 80, + expectedString: ` +hello world +`, + }, + "single input, multi-line, all lines under width": { + inputs: [][]byte{ + []byte("hello world\ntest test test\none two three"), + }, + width: 80, + expectedString: ` +hello world +test test test +one two three +`, + }, + "single input, multi-line, 1st line over width": { + inputs: [][]byte{ + []byte("hello world\ntest test\none two"), + }, + width: 10, + expectedString: ` +hello +world +test test +one two +`, + }, + "multi-input, multi-line, 1st line over width": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo wor"), + []byte("ld\ntest "), + []byte("test\none two"), + }, + width: 10, + expectedString: ` +hello +world +test test +one two +`, + }, + "multi-input, multi-line, 2nd line over width": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\ntesting "), + []byte("test\none two"), + }, + width: 10, + expectedString: ` +hello +testing +test +one two +`, + }, + "multi-input, multi-line, 2nd line over width, special symbols": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\nabc -"), + []byte("-key=v\none two"), + }, + width: 10, + expectedString: ` +hello +abc +--key=v +one two +`, + }, + "multi-input, multi-line, 2nd line over width, split with hard break": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\nabc -"), + []byte("-very-long-key=v\none two"), + }, + width: 10, + expectedString: ` +hello +abc +--very-long-key=v +one two +`, + }, + "multi-input, multi-line, 2nd line over width & cannot be broken": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\n--very-long-key=v\none two"), + }, + width: 10, + expectedString: ` +hello +--very-long-key=v +one two +`, + }, + "multi-input, multi-line, 2nd line splits exactly on width": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\n--very=v12\none two"), + }, + width: 10, + expectedString: ` +hello +--very=v12 +one two +`, + }, + "prefixed single input, simple single line under width": { + inputs: [][]byte{ + []byte("hello world"), + }, + width: 80, + prefix: " ", + expectedString: ` + hello world +`, + }, + "prefixed single input, multi-line, all lines under width": { + inputs: [][]byte{ + []byte("hello world\ntest test test\none two three"), + }, + width: 80, + prefix: " ", + expectedString: ` + hello world + test test test + one two three +`, + }, + "prefixed single input, multi-line, 1st line over width": { + inputs: [][]byte{ + []byte("hello world\ntest test\none two"), + }, + width: 10, + prefix: " ", + expectedString: ` + hello + world + test + test + one + two +`, + }, + "prefixed multi-input, multi-line, 1st line over width": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo wor"), + []byte("ld\ntest "), + []byte("test\none two"), + }, + width: 10, + prefix: " ", + expectedString: ` + hello + world + test + test + one + two +`, + }, + "prefixed multi-input, multi-line, 2nd line over width": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\ntesting "), + []byte("test\none two"), + }, + width: 10, + prefix: " ", + expectedString: ` + hello + testing + test + one + two +`, + }, + "prefixed multi-input, multi-line, 2nd line over width, special symbols": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\nabc -"), + []byte("-key=v\none two"), + }, + width: 10, + prefix: " ", + expectedString: ` + hello + abc + --key=v + one + two +`, + }, + "prefixed multi-input, multi-line, 2nd line over width, split with hard break": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\nabc -"), + []byte("-very-long-key=v\none two"), + }, + width: 10, + prefix: " ", + expectedString: ` + hello + abc + --very-long-key=v + one + two +`, + }, + "prefixed multi-input, multi-line, 2nd line over width & cannot be broken": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\n--very-long-key=v\none two"), + }, + width: 10, + prefix: " ", + expectedString: ` + hello + --very-long-key=v + one + two +`, + }, + "prefixed multi-input, multi-line, 2nd line splits exactly on width": { + inputs: [][]byte{ + []byte("hel"), + []byte("lo\n--very=v12\none two"), + }, + width: 10, + prefix: " ", + expectedString: ` + hello + --very=v12 + one + two +`, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + w, err := NewWrappingWriter(tc.width) + With(t).Verify(err).Will(BeNil()).OrFail() + if tc.prefix != "" { + With(t).Verify(w.SetLinePrefix(tc.prefix)).Will(Succeed()).OrFail() + } + + for _, input := range tc.inputs { + With(t).Verify(w.Write(input)).Will(Succeed()).OrFail() + } + + With(t).Verify(w.String()).Will(EqualTo(tc.expectedString[1 : len(tc.expectedString)-1])).OrFail() + }) + } +}