From 3e47d12ed0cee1249aa7a744bce1c5419bc9a1f9 Mon Sep 17 00:00:00 2001 From: Arik Kfir Date: Fri, 14 Jun 2024 14:39:42 +0300 Subject: [PATCH] feat(command): add support for post-run command hooks This change adds support for providing post-run command hooks, enabling commands to be (when needed) initialization & tear-down wrappers around concrete sub-commands. Unlike pre-run hooks (which are run in order, depth-last, before actual command execution), post-run hooks are run depth-first (reverse order). Both the pre & post run hooks can be structs that have their own flags and configurations. --- command.go | 22 +++++++- command_test.go | 69 ++++++++++++----------- execute.go | 59 ++++++++++++++----- execute_test.go | 146 +++++++++++++++++++++++++++++++++++++++++------- 4 files changed, 228 insertions(+), 68 deletions(-) diff --git a/command.go b/command.go index c297139..97a8182 100644 --- a/command.go +++ b/command.go @@ -47,6 +47,20 @@ func (i PreRunHookFunc) PreRun(ctx context.Context) error { } } +type PostRunHook interface { + PostRun(context.Context, error, ExitCode) error +} + +type PostRunHookFunc func(context.Context, error, ExitCode) error + +func (i PostRunHookFunc) PostRun(ctx context.Context, err error, exitCode ExitCode) error { + if i != nil { + return i(ctx, err, exitCode) + } else { + return nil + } +} + // Command is a command instance, created by [New] and can be composed with more Command instances to form a CLI command // hierarchy. type Command struct { @@ -54,6 +68,7 @@ type Command struct { shortDescription string longDescription string preRunHooks []PreRunHook + postRunHooks []PostRunHook action Action flags *flagSet parent *Command @@ -64,8 +79,8 @@ type Command struct { // MustNew creates a new command using [New], but will panic if it returns an error. // //goland:noinspection GoUnusedExportedFunction -func MustNew(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, subCommands ...*Command) *Command { - cmd, err := New(name, shortDescription, longDescription, action, preRunHooks, subCommands...) +func MustNew(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, postRunHooks []PostRunHook, subCommands ...*Command) *Command { + cmd, err := New(name, shortDescription, longDescription, action, preRunHooks, postRunHooks, subCommands...) if err != nil { panic(err) } @@ -74,7 +89,7 @@ func MustNew(name, shortDescription, longDescription string, action Action, preR // New creates a new command with the given name, short & long descriptions, and the given executor. The executor object // is also scanned for configuration structs via reflection. -func New(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, subCommands ...*Command) (*Command, error) { +func New(name, shortDescription, longDescription string, action Action, preRunHooks []PreRunHook, postRunHooks []PostRunHook, subCommands ...*Command) (*Command, error) { if name == "" { return nil, fmt.Errorf("%w: empty name", ErrInvalidCommand) } else if shortDescription == "" { @@ -88,6 +103,7 @@ func New(name, shortDescription, longDescription string, action Action, preRunHo longDescription: longDescription, action: action, preRunHooks: preRunHooks, + postRunHooks: postRunHooks, HelpConfig: &HelpConfig{}, } diff --git a/command_test.go b/command_test.go index 718af4c..9a468ef 100644 --- a/command_test.go +++ b/command_test.go @@ -25,19 +25,19 @@ func TestNew(t *testing.T) { testCases := map[string]testCase{ "empty name": { commandFactory: func(t T, tc *testCase) (*Command, error) { - return New("", "short desc", "long desc", nil, nil) + return New("", "short desc", "long desc", nil, nil, nil) }, expectedError: `^invalid command: empty name$`, }, "empty short description": { commandFactory: func(t T, tc *testCase) (*Command, error) { - return New("cmd", "", "long desc", nil, nil) + return New("cmd", "", "long desc", nil, nil, nil) }, expectedError: `^invalid command: empty short description$`, }, "no flags": { commandFactory: func(t T, tc *testCase) (*Command, error) { - return New("cmd", "desc", "long desc", nil, nil) + return New("cmd", "desc", "long desc", nil, nil, nil) }, expectedName: "cmd", expectedShortDescription: "desc", @@ -54,6 +54,7 @@ func TestNew(t *testing.T) { MyFlag string `flag:"true"` }{}, nil, + nil, ) }, expectedFlagSet: &flagSet{ @@ -96,13 +97,13 @@ func TestNew(t *testing.T) { func TestAddSubCommand(t *testing.T) { t.Parallel() - root, err := New("root", "desc", "description", nil, nil) + root, err := New("root", "desc", "description", nil, nil, nil) With(t).Verify(err).Will(BeNil()).OrFail() - sub1, err := New("sub1", "sub1 desc", "sub1 description", nil, nil) + sub1, err := New("sub1", "sub1 desc", "sub1 description", nil, nil, nil) With(t).Verify(err).Will(BeNil()).OrFail() - sub2, err := New("sub2", "sub2 desc", "sub2 description", nil, nil) + sub2, err := New("sub2", "sub2 desc", "sub2 description", nil, nil, nil) With(t).Verify(err).Will(BeNil()).OrFail() With(t).Verify(root.AddSubCommand(sub1)).Will(BeNil()).OrFail() @@ -123,10 +124,10 @@ func Test_inferCommandAndArgs(t *testing.T) { testCases := map[string]testCase{ "No arguments": { root: MustNew( - "root", "desc", "description", nil, nil, - MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, - MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, - MustNew("sub3", "sub3 desc", "sub3 description", nil, nil), + "root", "desc", "description", nil, nil, nil, + MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, + MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil, + MustNew("sub3", "sub3 desc", "sub3 description", nil, nil, nil), ), ), ), @@ -137,9 +138,9 @@ func Test_inferCommandAndArgs(t *testing.T) { }, "Flags for root command": { root: MustNew( - "root", "desc", "description", nil, nil, - MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, - MustNew("sub2", "sub2 desc", "sub2 description", nil, nil), + "root", "desc", "description", nil, nil, nil, + MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, + MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil), ), ), args: strings.Split("-f1 -f2", " "), @@ -149,9 +150,9 @@ func Test_inferCommandAndArgs(t *testing.T) { }, "Flags and positionals for root command": { root: MustNew( - "root", "desc", "description", nil, nil, - MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, - MustNew("sub2", "sub2 desc", "sub2 description", nil, nil), + "root", "desc", "description", nil, nil, nil, + MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, + MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil), ), ), args: strings.Split("-f1 a -f2 b", " "), @@ -161,9 +162,9 @@ func Test_inferCommandAndArgs(t *testing.T) { }, "Flags and positionals for sub1 command": { root: MustNew( - "root", "desc", "description", nil, nil, - MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, - MustNew("sub2", "sub2 desc", "sub2 description", nil, nil), + "root", "desc", "description", nil, nil, nil, + MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, + MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil), ), ), args: strings.Split("-f1 sub1 -f2 a b", " "), @@ -173,9 +174,9 @@ func Test_inferCommandAndArgs(t *testing.T) { }, "Flags and positionals for sub2 command": { root: MustNew( - "root", "desc", "description", nil, nil, - MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, - MustNew("sub2", "sub2 desc", "sub2 description", nil, nil), + "root", "desc", "description", nil, nil, nil, + MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, + MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil), ), ), args: strings.Split("-f1 sub1 -f2 a b sub2 c", " "), @@ -200,10 +201,10 @@ func Test_getFullName(t *testing.T) { cmd *Command expectedFullName string } - sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil) - sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, sub3) - sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, sub2) - root := MustNew("root", "desc", "description", nil, nil, sub1) + sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil, nil) + sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil, sub3) + sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, sub2) + root := MustNew("root", "desc", "description", nil, nil, nil, sub1) testCases := map[string]testCase{ "root": { cmd: root, @@ -235,10 +236,10 @@ func Test_getChain(t *testing.T) { cmd *Command expectedChain []string } - sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil) - sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, sub3) - sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, sub2) - root := MustNew("root", "desc", "description", nil, nil, sub1) + sub3 := MustNew("sub3", "sub3 desc", "sub3 description", nil, nil, nil) + sub2 := MustNew("sub2", "sub2 desc", "sub2 description", nil, nil, nil, sub3) + sub1 := MustNew("sub1", "sub1 desc", "sub1 description", nil, nil, nil, sub2) + root := MustNew("root", "desc", "description", nil, nil, nil, sub1) testCases := map[string]testCase{ "root": { cmd: root, @@ -281,7 +282,7 @@ func TestPrintHelp(t *testing.T) { "no flags & no positionals": { commandFactory: func(*testCase) *Command { ligen := loremipsum.NewWithSeed(4321) - return MustNew("cmd", ligen.Sentence(), ligen.Sentences(2), nil, nil) + return MustNew("cmd", ligen.Sentence(), ligen.Sentences(2), nil, nil, nil) }, expectedHelpUsageOutput: ` Usage: cmd [--help] @@ -324,7 +325,9 @@ Flags: MyFlag string `desc:"flag description"` Args []string `args:"true"` }{}, - nil) + nil, + nil, + ) }, expectedHelpUsageOutput: ` Usage: cmd [--help] @@ -374,6 +377,7 @@ Flags: Args []string `args:"true"` }{}, nil, + nil, MustNew( "child1", ligen.Sentence(), @@ -384,6 +388,7 @@ Flags: Args []string `args:"true"` }{}, nil, + nil, ), ) }, diff --git a/execute.go b/execute.go index 4869431..30a6dad 100644 --- a/execute.go +++ b/execute.go @@ -18,10 +18,14 @@ const ( // Execute the correct command in the given command hierarchy (starting at "root"), configured from the given CLI args // and environment variables. The command will be executed with the given context after all pre-RunFunc hooks have been // successfully executed in the command hierarchy. -func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) ExitCode { +func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) (exitCode ExitCode) { + exitCode = ExitCodeSuccess + + // We insist on getting the root command - so that we can infer correctly which command the user wanted to invoke if root.parent != nil { _, _ = fmt.Fprintf(w, "%s: command must be the root command", errors.ErrUnsupported) - return ExitCodeError + exitCode = ExitCodeError + return } // Extract the command, CLI flags, positional arguments & the command hierarchy @@ -33,25 +37,51 @@ func Execute(ctx context.Context, w io.Writer, root *Command, args []string, env _, _ = fmt.Fprintln(w, err) if err := cmd.PrintUsageLine(w, getTerminalWidth()); err != nil { _, _ = fmt.Fprintf(w, "%s\n", err) - return ExitCodeError + exitCode = ExitCodeError + return } else { - return ExitCodeMisconfiguration + exitCode = ExitCodeMisconfiguration + return } } else if cmd.HelpConfig.Help { if err := cmd.PrintHelp(w, getTerminalWidth()); err != nil { _, _ = fmt.Fprintf(w, "%s\n", err) - return ExitCodeError + exitCode = ExitCodeMisconfiguration + return } else { - return ExitCodeSuccess + exitCode = ExitCodeSuccess + return } } + // Results + var actionError error + + // Ensure we invoke post-run hooks before we return + chain := cmd.getChain() + defer func() { + for i := len(chain) - 1; i >= 0; i-- { + c := chain[i] + for j := len(c.postRunHooks) - 1; j >= 0; j-- { + h := c.postRunHooks[j] + if err := h.PostRun(ctx, actionError, exitCode); err != nil { + _, _ = fmt.Fprintln(w, err) + exitCode = ExitCodeError + } + } + } + }() + // Invoke all "PreRun" hooks on the whole chain of commands (starting at the root) - for _, c := range cmd.getChain() { - for _, hook := range c.preRunHooks { - if err := hook.PreRun(ctx); err != nil { + for i := 0; i < len(chain); i++ { + c := chain[i] + for j := 0; j < len(c.preRunHooks); j++ { + h := c.preRunHooks[j] + if err := h.PreRun(ctx); err != nil { _, _ = fmt.Fprintln(w, err) - return ExitCodeError + actionError = err + exitCode = ExitCodeError + return } } } @@ -60,15 +90,16 @@ func Execute(ctx context.Context, w io.Writer, root *Command, args []string, env if cmd.action != nil { if err := cmd.action.Run(ctx); err != nil { _, _ = fmt.Fprintln(w, err) - return ExitCodeError + actionError = err + exitCode = ExitCodeError } } else { // Command is not a runner - print help if err := cmd.PrintHelp(w, getTerminalWidth()); err != nil { _, _ = fmt.Fprintf(w, "%s\n", err) - return ExitCodeError + actionError = err + exitCode = ExitCodeError } } - return ExitCodeSuccess - + return } diff --git a/execute_test.go b/execute_test.go index 0301d3d..073de16 100644 --- a/execute_test.go +++ b/execute_test.go @@ -3,11 +3,13 @@ package command import ( "bytes" "context" + "fmt" "os" "testing" "time" . "github.com/arikkfir/justest" + "github.com/google/go-cmp/cmp/cmpopts" ) type TrackingAction struct { @@ -32,6 +34,21 @@ func (a *TrackingPreRunHook) PreRun(_ context.Context) error { return a.errorToReturnOnCall } +type TrackingPostRunHook struct { + callTime *time.Time + providedActionError error + providedExitCode ExitCode + errorToReturnOnCall error +} + +func (a *TrackingPostRunHook) PostRun(_ context.Context, actionError error, exitCode ExitCode) error { + a.callTime = ptrOf(time.Now()) + a.providedActionError = actionError + a.providedExitCode = exitCode + time.Sleep(100 * time.Millisecond) + return a.errorToReturnOnCall +} + type ActionWithConfig struct { TrackingAction MyFlag string `name:"my-flag"` @@ -42,13 +59,18 @@ type PreRunHookWithConfig struct { MyFlag string `name:"my-flag"` } +type PostRunHookWithConfig struct { + TrackingPostRunHook + MyFlag string `name:"my-flag"` +} + func TestExecute(t *testing.T) { t.Parallel() t.Run("command must be root", func(t *testing.T) { ctx := context.Background() - child := MustNew("child", "desc", "long desc", nil, nil) - _ = MustNew("root", "desc", "long desc", nil, nil, child) + child := MustNew("child", "desc", "long desc", nil, nil, nil) + _ = MustNew("root", "desc", "long desc", nil, nil, nil, child) b := &bytes.Buffer{} With(t).Verify(Execute(ctx, b, child, nil, nil)).Will(EqualTo(ExitCodeError)).OrFail() With(t).Verify(b).Will(Say(`^unsupported operation: command must be the root command$`)).OrFail() @@ -56,14 +78,14 @@ func TestExecute(t *testing.T) { t.Run("applies configuration", func(t *testing.T) { ctx := context.Background() - cmd := MustNew("cmd", "desc", "long desc", &ActionWithConfig{}, nil) + cmd := MustNew("cmd", "desc", "long desc", &ActionWithConfig{}, nil, nil) With(t).Verify(Execute(ctx, os.Stderr, cmd, []string{"--my-flag=V1"}, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail() With(t).Verify(cmd.action.(*ActionWithConfig).MyFlag).Will(EqualTo("V1")).OrFail() }) t.Run("prints usage on CLI parse errors", func(t *testing.T) { ctx := context.Background() - cmd := MustNew("cmd", "desc", "long desc", &ActionWithConfig{}, nil) + cmd := MustNew("cmd", "desc", "long desc", &ActionWithConfig{}, nil, nil) b := &bytes.Buffer{} With(t).Verify(Execute(ctx, b, cmd, []string{"--bad-flag=V1"}, nil)).Will(EqualTo(ExitCodeMisconfiguration)).OrFail() With(t).Verify(cmd.action.(*ActionWithConfig).MyFlag).Will(BeEmpty()).OrFail() @@ -72,7 +94,7 @@ func TestExecute(t *testing.T) { t.Run("prints help on --help flag", func(t *testing.T) { ctx := context.Background() - cmd := MustNew("cmd", "desc", "long desc", &ActionWithConfig{}, nil) + cmd := MustNew("cmd", "desc", "long desc", &ActionWithConfig{}, nil, nil) b := &bytes.Buffer{} With(t).Verify(Execute(ctx, b, cmd, []string{"--help"}, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail() With(t).Verify(b.String()).Will(EqualTo(` @@ -93,24 +115,110 @@ Flags: t.Run("preRun called for command chain", func(t *testing.T) { ctx := context.Background() - sub2 := MustNew("sub2", "desc", "long desc", &ActionWithConfig{}, []PreRunHook{&PreRunHookWithConfig{}}) - sub1 := MustNew("sub1", "desc", "long desc", &ActionWithConfig{}, []PreRunHook{&PreRunHookWithConfig{}}, sub2) - root := MustNew("cmd", "desc", "long desc", &ActionWithConfig{}, []PreRunHook{&PreRunHookWithConfig{}}, sub1) + sub2 := MustNew("sub2", "desc", "long desc", &ActionWithConfig{}, []PreRunHook{&PreRunHookWithConfig{}}, nil) + sub1 := MustNew("sub1", "desc", "long desc", nil, []PreRunHook{&PreRunHookWithConfig{}}, nil, sub2) + root := MustNew("cmd", "desc", "long desc", nil, []PreRunHook{&PreRunHookWithConfig{}}, nil, sub1) With(t).Verify(Execute(ctx, os.Stderr, root, []string{"sub1", "sub2"}, nil)).Will(EqualTo(ExitCodeSuccess)).OrFail() - sub2PreRunTime := sub2.preRunHooks[0].(*PreRunHookWithConfig).callTime - With(t).Verify(sub2PreRunTime).Will(Not(BeNil())).OrFail() + rootPreRunHook := root.preRunHooks[0].(*PreRunHookWithConfig) + sub1PreRunHook := sub1.preRunHooks[0].(*PreRunHookWithConfig) + sub2PreRunHook := sub2.preRunHooks[0].(*PreRunHookWithConfig) + sub2Action := sub2.action.(*ActionWithConfig) + + With(t).Verify(rootPreRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(rootPreRunHook.callTime.Before(*sub1PreRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub1PreRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub1PreRunHook.callTime.Before(*sub2PreRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub2PreRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2PreRunHook.callTime.Before(*sub2Action.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub2Action.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2Action.callTime.After(*sub2PreRunHook.callTime)).Will(EqualTo(true)).OrFail() + }) + + t.Run("preRun failure stops execution", func(t *testing.T) { + failingPreHook := &PreRunHookWithConfig{TrackingPreRunHook: TrackingPreRunHook{errorToReturnOnCall: fmt.Errorf("fail")}} + passThroughPreHook := func() PreRunHook { return &PreRunHookWithConfig{} } + + ctx := context.Background() + sub2 := MustNew("sub2", "desc", "long desc", &ActionWithConfig{}, []PreRunHook{passThroughPreHook()}, nil) + sub1 := MustNew("sub1", "desc", "long desc", nil, []PreRunHook{passThroughPreHook(), failingPreHook}, nil, sub2) + root := MustNew("cmd", "desc", "long desc", nil, []PreRunHook{passThroughPreHook()}, nil, sub1) + + rootPreRunHook := root.preRunHooks[0].(*PreRunHookWithConfig) + sub1PreRunHook := sub1.preRunHooks[0].(*PreRunHookWithConfig) + sub2PreRunHook := sub2.preRunHooks[0].(*PreRunHookWithConfig) + sub2Action := sub2.action.(*ActionWithConfig) + + With(t).Verify(Execute(ctx, os.Stderr, root, []string{"sub1", "sub2"}, nil)).Will(EqualTo(ExitCodeError)).OrFail() + With(t).Verify(rootPreRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(rootPreRunHook.callTime.Before(*sub1PreRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub1PreRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2PreRunHook.callTime).Will(BeNil()).OrFail() + With(t).Verify(sub2Action.callTime).Will(BeNil()).OrFail() + }) - sub1PreRunTime := sub1.preRunHooks[0].(*PreRunHookWithConfig).callTime - With(t).Verify(sub1PreRunTime).Will(Not(BeNil())).OrFail() - With(t).Verify(sub1PreRunTime.Before(*sub2PreRunTime)).Will(EqualTo(true)).OrFail() + t.Run("postRun called for command chain", func(t *testing.T) { + ctx := context.Background() + sub2 := MustNew("sub2", "desc", "long desc", &ActionWithConfig{}, nil, []PostRunHook{&PostRunHookWithConfig{}}) + sub1 := MustNew("sub1", "desc", "long desc", nil, nil, []PostRunHook{&PostRunHookWithConfig{}}, sub2) + root := MustNew("cmd", "desc", "long desc", nil, nil, []PostRunHook{&PostRunHookWithConfig{}}, sub1) + + exitCode := Execute(ctx, os.Stderr, root, []string{"sub1", "sub2"}, nil) + With(t).Verify(exitCode).Will(EqualTo(ExitCodeSuccess)).OrFail() + + rootPostRunHook := root.postRunHooks[0].(*PostRunHookWithConfig) + sub1PostRunHook := sub1.postRunHooks[0].(*PostRunHookWithConfig) + sub2PostRunHook := sub2.postRunHooks[0].(*PostRunHookWithConfig) + sub2Action := sub2.action.(*ActionWithConfig) + + With(t).Verify(sub2Action.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2Action.callTime.Before(*sub2PostRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub2PostRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2PostRunHook.callTime.Before(*sub1PostRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub2PostRunHook.providedActionError).Will(EqualTo(sub2Action.errorToReturnOnCall)).OrFail() + With(t).Verify(sub2PostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail() + With(t).Verify(sub1PostRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub1PostRunHook.callTime.Before(*rootPostRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub1PostRunHook.providedActionError).Will(EqualTo(sub2PostRunHook.errorToReturnOnCall)).OrFail() + With(t).Verify(sub1PostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail() + With(t).Verify(rootPostRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(rootPostRunHook.providedActionError).Will(BeNil()).OrFail() + With(t).Verify(rootPostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail() + }) - rootPreRunTime := root.preRunHooks[0].(*PreRunHookWithConfig).callTime - With(t).Verify(rootPreRunTime).Will(Not(BeNil())).OrFail() - With(t).Verify(rootPreRunTime.Before(*sub1PreRunTime)).Will(EqualTo(true)).OrFail() + t.Run("postRun chain called in full, even on action or hook error", func(t *testing.T) { + failingPostHook := func() PostRunHook { + return &PostRunHookWithConfig{TrackingPostRunHook: TrackingPostRunHook{errorToReturnOnCall: fmt.Errorf("failing post hook")}} + } + passThroughPostHook := func() PostRunHook { return &PostRunHookWithConfig{} } + failingAction := &ActionWithConfig{TrackingAction: TrackingAction{errorToReturnOnCall: fmt.Errorf("failing action")}} - sub2RunTime := sub2.action.(*ActionWithConfig).callTime - With(t).Verify(sub2RunTime).Will(Not(BeNil())).OrFail() - With(t).Verify(sub2RunTime.After(*sub2PreRunTime)).Will(EqualTo(true)).OrFail() + ctx := context.Background() + sub2 := MustNew("sub2", "desc", "long desc", failingAction, nil, []PostRunHook{failingPostHook()}) + sub1 := MustNew("sub1", "desc", "long desc", nil, nil, []PostRunHook{passThroughPostHook()}, sub2) + root := MustNew("cmd", "desc", "long desc", nil, nil, []PostRunHook{passThroughPostHook()}, sub1) + + exitCode := Execute(ctx, os.Stderr, root, []string{"sub1", "sub2"}, nil) + With(t).Verify(exitCode).Will(EqualTo(ExitCodeError)).OrFail() + + rootPostRunHook := root.postRunHooks[0].(*PostRunHookWithConfig) + sub1PostRunHook := sub1.postRunHooks[0].(*PostRunHookWithConfig) + sub2PostRunHook := sub2.postRunHooks[0].(*PostRunHookWithConfig) + sub2Action := sub2.action.(*ActionWithConfig) + + With(t).Verify(sub2Action.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2Action.callTime.Before(*sub2PostRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub2PostRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub2PostRunHook.callTime.Before(*sub1PostRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub2PostRunHook.providedActionError).Will(EqualTo(sub2Action.errorToReturnOnCall, cmpopts.EquateErrors())).OrFail() + With(t).Verify(sub2PostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail() + With(t).Verify(sub1PostRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(sub1PostRunHook.callTime.Before(*rootPostRunHook.callTime)).Will(EqualTo(true)).OrFail() + With(t).Verify(sub1PostRunHook.providedActionError).Will(EqualTo(sub2Action.errorToReturnOnCall, cmpopts.EquateErrors())).OrFail() + With(t).Verify(sub1PostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail() + With(t).Verify(rootPostRunHook.callTime).Will(Not(BeNil())).OrFail() + With(t).Verify(rootPostRunHook.providedActionError).Will(EqualTo(sub2Action.errorToReturnOnCall, cmpopts.EquateErrors())).OrFail() + With(t).Verify(rootPostRunHook.providedExitCode).Will(EqualTo(exitCode)).OrFail() }) + }