diff --git a/command.go b/command.go index bf53341..9a7d23b 100644 --- a/command.go +++ b/command.go @@ -6,7 +6,6 @@ import ( "flag" "fmt" "io" - "os" "reflect" "regexp" "slices" @@ -18,10 +17,7 @@ import ( var Version = "0.0.0-unknown" var ( - tokenRE = regexp.MustCompile(`^([^=]+)=(.*)$`) - builtinConfig = &BuiltinConfig{ - Help: false, - } + tokenRE = regexp.MustCompile(`^([^=]+)=(.*)$`) ) func New(parent *Command, spec Spec) *Command { @@ -33,12 +29,14 @@ func New(parent *Command, spec Spec) *Command { ShortDescription: spec.ShortDescription, LongDescription: spec.LongDescription, Config: spec.Config, + PreSubCommandRun: spec.OnSubCommandRun, Run: spec.Run, - Parent: parent, + builtinConfig: &BuiltinConfig{Help: false}, + parent: parent, createdByNewCommand: true, } - if cmd.Parent != nil { - cmd.Parent.subCommands = append(cmd.Parent.subCommands, cmd) + if cmd.parent != nil { + cmd.parent.subCommands = append(cmd.parent.subCommands, cmd) } return cmd } @@ -48,6 +46,7 @@ type Spec struct { ShortDescription string LongDescription string Config any + OnSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error } @@ -55,10 +54,12 @@ type Command struct { Name string ShortDescription string LongDescription string - Parent *Command - subCommands []*Command Config any + PreSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error + builtinConfig any + parent *Command + subCommands []*Command createdByNewCommand bool envVarsMapping map[string]reflect.Value flagSet *flag.FlagSet @@ -91,12 +92,12 @@ func (c *Command) initializeFlagSet() error { // Create a flag set name := c.Name - for parent := c.Parent; parent != nil; parent = parent.Parent { + for parent := c.parent; parent != nil; parent = parent.parent { name = parent.Name + " " + name } c.flagSet = flag.NewFlagSet(name, flag.ContinueOnError) c.flagSet.SetOutput(io.Discard) - if err := c.initializeFlagSetFromStruct(reflect.ValueOf(builtinConfig).Elem()); err != nil { + if err := c.initializeFlagSetFromStruct(reflect.ValueOf(c.builtinConfig).Elem()); err != nil { return fmt.Errorf("failed to process builtin configuration fields: %w", err) } @@ -261,19 +262,17 @@ func (c *Command) applyEnvironmentVariables(envVars map[string]string) error { return nil } -func (c *Command) configure(envVars map[string]string, args []string) error { +func (c *Command) applyCLIArguments(args []string) error { - // Apply environment variables first - if err := c.applyEnvironmentVariables(envVars); err != nil { - return fmt.Errorf("failed to apply environment variables: %w", err) - } - - // Override with CLI arguments + // Update config with CLI arguments if err := c.flagSet.Parse(args); err != nil { return fmt.Errorf("failed to apply CLI arguments: %w", err) } - // Ensure all required flags have been provided via either CLI or via environment variables + return nil +} + +func (c *Command) validateRequiredFlagsWereProvided(envVars map[string]string) error { var missingRequiredFlags []string copy(missingRequiredFlags, c.requiredFlags) c.flagSet.Visit(func(f *flag.Flag) { @@ -289,20 +288,40 @@ func (c *Command) configure(envVars map[string]string, args []string) error { }) } } + if len(missingRequiredFlags) > 0 { + return fmt.Errorf("these required flags have not set via either CLI nor environment variables: %v", missingRequiredFlags) + } + return nil +} + +func (c *Command) configure(envVars map[string]string, args []string) error { + + // Initialize the flagSet for the chosen command + if err := c.initializeFlagSet(); err != nil { + panic(fmt.Sprintf("failed to initialize flag set for command '%s': %v", c.Name, err)) + } + + // Apply environment variables first + if err := c.applyEnvironmentVariables(envVars); err != nil { + return fmt.Errorf("failed to apply environment variables: %w", err) + } + + // Override with CLI arguments + if err := c.flagSet.Parse(args); err != nil { + return fmt.Errorf("failed to apply CLI arguments: %w", err) + } // Apply positional arguments if c.positionalArgsTarget != nil { *c.positionalArgsTarget = c.flagSet.Args() } - if len(missingRequiredFlags) > 0 { - return fmt.Errorf("these required flags have not set via either CLI nor environment variables: %v", missingRequiredFlags) - } + return nil } func (c *Command) printCommandUsage(w io.Writer, short bool) { cmdChain := c.Name - for cmd := c.Parent; cmd != nil; cmd = cmd.Parent { + for cmd := c.parent; cmd != nil; cmd = cmd.parent { cmdChain = cmd.Name + " " + cmdChain } @@ -392,10 +411,9 @@ func (c *Command) printCommandUsage(w io.Writer, short bool) { } } -//goland:noinspection GoUnusedExportedFunction -func Execute(root *Command, args []string, envVars map[string]string) { +func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) (exitCode int) { if !root.createdByNewCommand { - panic("illegal root command was specified - was it created by 'command.New(...)'?") + panic("invalid root command given, indicating it may not have been created by 'command.New(...)'") } // Iterate CLI args, separate them to flags & positional args, but also infer the command to execute from the given @@ -411,31 +429,44 @@ func Execute(root *Command, args []string, envVars map[string]string) { // positional args: [something, sub3, a, b, c]: no "cmd1", "sub1" and "sub2" as they are commands in the hierarchy cmd, flagArgs, positionalArgs := inferCommandFlagsAndPositionals(root, args) - // Initialize the flagSet for the chosen command - if err := cmd.initializeFlagSet(); err != nil { - panic(fmt.Sprintf("failed to initialize flag set for command '%s': %v", cmd.Name, err)) + // Build the command chain from top-to-bottom (so index 0 is the root) + commandChain := []*Command{cmd} + parent := cmd.parent + for parent != nil { + commandChain = append([]*Command{parent}, commandChain...) + parent = parent.parent } - // Parse the arguments as returned in the parsing step - if err := cmd.configure(envVars, append(flagArgs, positionalArgs...)); err != nil { - cmd.PrintShortUsage(os.Stderr) - os.Exit(1) - } else if cmd.flagSet.Lookup("help").Value.String() == "true" { - cmd.PrintFullUsage(os.Stderr) - os.Exit(0) + // Configure commands up the chain, in order to invoke their "PreSubCommandRun" function + for _, current := range commandChain { + if err := current.configure(envVars, append(flagArgs, positionalArgs...)); err != nil { + current.PrintShortUsage(w) + return 1 + } + + if err := current.PreSubCommandRun(ctx, current.Config, current); err != nil { + _, _ = fmt.Fprintln(w, err.Error()) + return 1 + } + } + + // If "--help" was provided, show usage and exit immediately + if cmd.flagSet.Lookup("help").Value.String() == "true" { + cmd.PrintFullUsage(w) + return 0 } // If command has no "Run" function, it's an intermediate probably - just print its usage and exit successfully if cmd.Run == nil { - cmd.PrintFullUsage(os.Stderr) + cmd.PrintFullUsage(w) + return 0 } - // Run the command with a fresh context - ctx, cancel := context.WithCancel(SetupSignalHandler()) - defer cancel() + // Run the command if err := cmd.Run(ctx, cmd.Config, cmd); err != nil { - cancel() // os.Exit might not invoke the deferred cancel call - _, _ = fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) + _, _ = fmt.Fprintln(w, err.Error()) + return 1 } + + return 0 } diff --git a/command_test.go b/command_test.go index f181ffa..285541b 100644 --- a/command_test.go +++ b/command_test.go @@ -2,6 +2,8 @@ package command import ( "bytes" + "context" + "regexp" "testing" . "github.com/arikkfir/justest" @@ -15,14 +17,14 @@ func Test_initializeFlagSet(t *testing.T) { DefValue string } type testCase struct { - cmd Command + cmd *Command expectedFlags []Flag expectedFailure *string } testCases := map[string]testCase{ - "nil Config": {cmd: Command{Config: nil}}, + "nil Config": {cmd: New(nil, Spec{Config: nil})}, "Config is a pointer to a struct": { - cmd: Command{Config: &RootConfig{S0: "v0"}}, + cmd: New(nil, Spec{Config: &RootConfig{S0: "v0"}}), expectedFlags: []Flag{ {Name: "s0", DefValue: "v0", Usage: "String field", Value: nil}, {Name: "b0", DefValue: "false", Usage: "Bool field", Value: nil}, @@ -48,15 +50,15 @@ func Test_initializeFlagSet(t *testing.T) { } t.Run("Config must be a pointer to a struct", func(t *testing.T) { With(t). - Verify((&Command{Config: RootConfig{S0: "v0"}}).initializeFlagSet()). + Verify(New(nil, Spec{Config: RootConfig{S0: "v0"}}).initializeFlagSet()). Will(Fail(`must not be a struct, but a pointer to a struct`)). OrFail() With(t). - Verify((&Command{Config: 1}).initializeFlagSet()). + Verify(New(nil, Spec{Config: 1}).initializeFlagSet()). Will(Fail(`is not a pointer: 1`)). OrFail() With(t). - Verify((&Command{Config: &[]int{123}[0]}).initializeFlagSet()). + Verify(New(nil, Spec{Config: &[]int{123}[0]}).initializeFlagSet()). Will(Fail(`is not a pointer to struct`)). OrFail() }) @@ -200,3 +202,133 @@ Flags: }) } } + +func TestExecute(t *testing.T) { + t.Parallel() + + rootSpec := Spec{Name: "root", ShortDescription: "Root!", LongDescription: "The root command.", Config: &RootConfig{}} + sub1Spec := Spec{Name: "sub1", ShortDescription: "Sub 1", LongDescription: "The first sub command.", Config: &Sub1Config{}} + sub2Spec := Spec{Name: "sub2", ShortDescription: "Sub 2", LongDescription: "The second sub command.", Config: &Sub2Config{}} + sub3Spec := Spec{Name: "sub3", ShortDescription: "Sub 3", LongDescription: "The third sub command.", Config: &Sub3Config{}} + + type testCase struct { + args []string + envVars []string + expectedExitCode int + expectedPreRunCalls map[string]bool + expectedRunCalls map[string]bool + expectedOutput string + } + testCases := map[string]testCase{ + "": { + args: nil, + envVars: nil, + expectedExitCode: 0, + expectedPreRunCalls: map[string]bool{"root": true}, + expectedRunCalls: map[string]bool{"root": true}, + }, + "sub1": { + args: []string{"sub1"}, + envVars: nil, + expectedExitCode: 0, + expectedPreRunCalls: map[string]bool{"root": true, "sub1": true}, + expectedRunCalls: map[string]bool{"sub1": true}, + }, + "sub2 sub3": { + args: []string{"sub2", "sub3"}, + envVars: nil, + expectedExitCode: 0, + expectedPreRunCalls: map[string]bool{"root": true, "sub2": true, "sub3": true}, + expectedRunCalls: map[string]bool{"sub3": true}, + }, + "sub2 sub3 --help": { + args: []string{"sub2", "sub3", "--help"}, + envVars: nil, + expectedExitCode: 0, + expectedPreRunCalls: map[string]bool{"root": true, "sub2": true, "sub3": true}, + expectedRunCalls: map[string]bool{}, + expectedOutput: `root sub2 sub3: Sub 3 + +The third sub command. + +Usage: + root sub2 sub3 [--b0] [--b1] [--b2] [--b3] [--help] --s0=VAL --s1=VALUE [--s2=VALUE] [--s3=VALUE] [ARGS] + +Flags: + --b0 Bool field (default is false) + --b1 Bool field (default is false) + --b2 Bool field (default is false) + --b3 Bool field (default is false) + --help Show help about how to use this command (default is false) + --s0 String field + --s1 String field + --s2 String field + --s3 String field + +`, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + preRunCalls := make(map[string]bool) + runCalls := make(map[string]bool) + + rootSpec := rootSpec + rootSpec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + preRunCalls["root"] = true + return nil + } + rootSpec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + runCalls["root"] = true + return nil + } + rootCmd := New(nil, rootSpec) + + sub1Spec := sub1Spec + sub1Spec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + preRunCalls["sub1"] = true + return nil + } + sub1Spec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + runCalls["sub1"] = true + return nil + } + sub1Cmd := New(rootCmd, sub1Spec) + + sub2Spec := sub2Spec + sub2Spec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + preRunCalls["sub2"] = true + return nil + } + sub2Spec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + runCalls["sub2"] = true + return nil + } + sub2Cmd := New(rootCmd, sub2Spec) + + sub3Spec := sub3Spec + sub3Spec.OnSubCommandRun = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + preRunCalls["sub3"] = true + return nil + } + sub3Spec.Run = func(ctx context.Context, config any, usagePrinter UsagePrinter) error { + runCalls["sub3"] = true + return nil + } + sub3Cmd := New(sub2Cmd, sub3Spec) + _, _, _, _ = rootCmd, sub1Cmd, sub2Cmd, sub3Cmd + + b := &bytes.Buffer{} + exitCode := Execute(context.Background(), b, rootCmd, tc.args, EnvVarsArrayToMap(tc.envVars)) + With(t).Verify(exitCode).Will(EqualTo(tc.expectedExitCode)).OrFail() + With(t).Verify(preRunCalls).Will(EqualTo(tc.expectedPreRunCalls)).OrFail() + With(t).Verify(runCalls).Will(EqualTo(tc.expectedRunCalls)).OrFail() + if tc.expectedOutput != "" { + With(t).Verify(b).Will(Say(`^` + regexp.QuoteMeta(tc.expectedOutput) + `$`)).OrFail() + } + }) + } +}