Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ func ExecuteWithContext(ctx context.Context, w io.Writer, root *Command, args []
// Ensure we invoke post-run hooks before we return
chain := cmd.getChain()
defer func() {
postHooksCtx := context.Background()
for i := len(chain) - 1; i >= 0; i-- {
c := chain[i]
for j := len(c.postRunHooks) - 1; j >= 0; j-- {
h := c.postRunHooks[j]
if err := h.PostRun(ctx, actionError, exitCode); err != nil {
if err := h.PostRun(postHooksCtx, actionError, exitCode); err != nil {
_, _ = fmt.Fprintln(w, err)
exitCode = ExitCodeError
}
Expand Down
28 changes: 26 additions & 2 deletions execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import (

type TrackingAction struct {
callTime *time.Time
providedCtx context.Context
errorToReturnOnCall error
}

func (a *TrackingAction) Run(_ context.Context) error {
func (a *TrackingAction) Run(ctx context.Context) error {
a.callTime = ptrOf(time.Now())
a.providedCtx = ctx
time.Sleep(100 * time.Millisecond)
return a.errorToReturnOnCall
}
Expand All @@ -36,13 +38,15 @@ func (a *TrackingPreRunHook) PreRun(_ context.Context) error {

type TrackingPostRunHook struct {
callTime *time.Time
providedCtx context.Context
providedActionError error
providedExitCode ExitCode
errorToReturnOnCall error
}

func (a *TrackingPostRunHook) PostRun(_ context.Context, actionError error, exitCode ExitCode) error {
func (a *TrackingPostRunHook) PostRun(ctx context.Context, actionError error, exitCode ExitCode) error {
a.callTime = ptrOf(time.Now())
a.providedCtx = ctx
a.providedActionError = actionError
a.providedExitCode = exitCode
time.Sleep(100 * time.Millisecond)
Expand Down Expand Up @@ -252,4 +256,24 @@ Flags:
With(t).Verify(action.TrackingAction.callTime).Will(Not(BeNil())).OrFail()
With(t).Verify(b.String()).Will(BeEmpty()).OrFail()
})

t.Run("ensure post-hooks use fresh context", func(t *testing.T) {
//nolint:all
executionCtx := context.WithValue(context.Background(), "k", "v")

action := &TrackingAction{}
root := MustNew("cmd", "desc", "long desc", action, []any{&PostRunHookWithConfig{}})

exitCode := ExecuteWithContext(executionCtx, os.Stderr, root, nil, nil)
With(t).Verify(exitCode).Will(EqualTo(ExitCodeSuccess)).OrFail()

if action.providedCtx != executionCtx {
t.Fatalf("incorrect context passed to action: %+v", action.providedCtx)
}

rootPostRunHook := root.postRunHooks[0].(*PostRunHookWithConfig)
if rootPostRunHook.providedCtx == executionCtx {
t.Fatalf("incorrect context passed to posthook: %+v", rootPostRunHook.providedCtx)
}
})
}