From 1a9d807fd6415c4c598eadcd8e70b088f288e5f2 Mon Sep 17 00:00:00 2001 From: Arik Kfir Date: Fri, 12 Jul 2024 22:46:35 +0300 Subject: [PATCH] fix(execute): use fresh context for post-run hooks This change ensures that command post-run hooks receive a fresh context instead of the original context passed to the `Execute` functions. This is needed since often the context passed to the `Execute` functions is canceled by the time the post-run hooks are ran - either by the OS signal, deadlines, or other reasons. --- execute.go | 3 ++- execute_test.go | 28 ++++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/execute.go b/execute.go index bfb9575..c3e1eb2 100644 --- a/execute.go +++ b/execute.go @@ -60,11 +60,12 @@ func ExecuteWithContext(ctx context.Context, w io.Writer, root *Command, args [] // Ensure we invoke post-run hooks before we return chain := cmd.getChain() defer func() { + postHooksCtx := context.Background() for i := len(chain) - 1; i >= 0; i-- { c := chain[i] for j := len(c.postRunHooks) - 1; j >= 0; j-- { h := c.postRunHooks[j] - if err := h.PostRun(ctx, actionError, exitCode); err != nil { + if err := h.PostRun(postHooksCtx, actionError, exitCode); err != nil { _, _ = fmt.Fprintln(w, err) exitCode = ExitCodeError } diff --git a/execute_test.go b/execute_test.go index cdaee58..b50dc57 100644 --- a/execute_test.go +++ b/execute_test.go @@ -14,11 +14,13 @@ import ( type TrackingAction struct { callTime *time.Time + providedCtx context.Context errorToReturnOnCall error } -func (a *TrackingAction) Run(_ context.Context) error { +func (a *TrackingAction) Run(ctx context.Context) error { a.callTime = ptrOf(time.Now()) + a.providedCtx = ctx time.Sleep(100 * time.Millisecond) return a.errorToReturnOnCall } @@ -36,13 +38,15 @@ func (a *TrackingPreRunHook) PreRun(_ context.Context) error { type TrackingPostRunHook struct { callTime *time.Time + providedCtx context.Context providedActionError error providedExitCode ExitCode errorToReturnOnCall error } -func (a *TrackingPostRunHook) PostRun(_ context.Context, actionError error, exitCode ExitCode) error { +func (a *TrackingPostRunHook) PostRun(ctx context.Context, actionError error, exitCode ExitCode) error { a.callTime = ptrOf(time.Now()) + a.providedCtx = ctx a.providedActionError = actionError a.providedExitCode = exitCode time.Sleep(100 * time.Millisecond) @@ -252,4 +256,24 @@ Flags: With(t).Verify(action.TrackingAction.callTime).Will(Not(BeNil())).OrFail() With(t).Verify(b.String()).Will(BeEmpty()).OrFail() }) + + t.Run("ensure post-hooks use fresh context", func(t *testing.T) { + //nolint:all + executionCtx := context.WithValue(context.Background(), "k", "v") + + action := &TrackingAction{} + root := MustNew("cmd", "desc", "long desc", action, []any{&PostRunHookWithConfig{}}) + + exitCode := ExecuteWithContext(executionCtx, os.Stderr, root, nil, nil) + With(t).Verify(exitCode).Will(EqualTo(ExitCodeSuccess)).OrFail() + + if action.providedCtx != executionCtx { + t.Fatalf("incorrect context passed to action: %+v", action.providedCtx) + } + + rootPostRunHook := root.postRunHooks[0].(*PostRunHookWithConfig) + if rootPostRunHook.providedCtx == executionCtx { + t.Fatalf("incorrect context passed to posthook: %+v", rootPostRunHook.providedCtx) + } + }) }