diff --git a/cmd/gh-aw/main.go b/cmd/gh-aw/main.go index 80a3855aa08..9d64d7d6a12 100644 --- a/cmd/gh-aw/main.go +++ b/cmd/gh-aw/main.go @@ -1,11 +1,14 @@ package main import ( + "context" "errors" "fmt" "os" + "os/signal" "sort" "strings" + "syscall" "github.com/github/gh-aw/pkg/cli" "github.com/github/gh-aw/pkg/console" @@ -829,7 +832,12 @@ func main() { // Set release flag in the workflow package workflow.SetIsRelease(isRelease == "true") - if err := rootCmd.Execute(); err != nil { + // Set up a context that is cancelled when Ctrl-C (SIGINT) or SIGTERM is received. + // This ensures all commands and subprocesses are properly interrupted on Ctrl-C. + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := rootCmd.ExecuteContext(ctx); err != nil { errMsg := err.Error() // Check if error is already formatted to avoid double formatting: // - Contains suggestions (FormatErrorWithSuggestions) diff --git a/pkg/cli/add_interactive_auth.go b/pkg/cli/add_interactive_auth.go index f00c4bb8bea..3e82a88b917 100644 --- a/pkg/cli/add_interactive_auth.go +++ b/pkg/cli/add_interactive_auth.go @@ -58,7 +58,7 @@ func (c *AddInteractiveConfig) checkGitRepository() error { ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(c.Ctx); err != nil { return fmt.Errorf("failed to get repository info: %w", err) } diff --git a/pkg/cli/add_interactive_engine.go b/pkg/cli/add_interactive_engine.go index 7dbe6349aa9..b18732d2fbb 100644 --- a/pkg/cli/add_interactive_engine.go +++ b/pkg/cli/add_interactive_engine.go @@ -126,7 +126,7 @@ func (c *AddInteractiveConfig) selectAIEngineAndKey() error { ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(c.Ctx); err != nil { return fmt.Errorf("failed to select coding agent: %w", err) } @@ -167,6 +167,7 @@ func (c *AddInteractiveConfig) configureEngineAPISecret(engine string) error { // Use the unified checkAndEnsureEngineSecrets function config := EngineSecretConfig{ + Ctx: c.Ctx, RepoSlug: c.RepoOverride, Engine: engine, Verbose: c.Verbose, diff --git a/pkg/cli/add_interactive_orchestrator.go b/pkg/cli/add_interactive_orchestrator.go index 2e00285b1e5..124783ac312 100644 --- a/pkg/cli/add_interactive_orchestrator.go +++ b/pkg/cli/add_interactive_orchestrator.go @@ -17,6 +17,7 @@ var addInteractiveLog = logger.New("cli:add_interactive") // AddInteractiveConfig holds configuration for interactive add mode type AddInteractiveConfig struct { + Ctx context.Context // Context for cancellation (Ctrl-C handling) WorkflowSpecs []string Verbose bool EngineOverride string @@ -71,6 +72,7 @@ func RunAddInteractive(ctx context.Context, workflowSpecs []string, verbose bool } config := &AddInteractiveConfig{ + Ctx: ctx, WorkflowSpecs: workflowSpecs, Verbose: verbose, EngineOverride: engineOverride, @@ -231,7 +233,7 @@ func (c *AddInteractiveConfig) confirmChanges(workflowFiles, initFiles []string, ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(c.Ctx); err != nil { return fmt.Errorf("confirmation failed: %w", err) } diff --git a/pkg/cli/add_interactive_schedule.go b/pkg/cli/add_interactive_schedule.go index 07cdafc4327..0f98ed15573 100644 --- a/pkg/cli/add_interactive_schedule.go +++ b/pkg/cli/add_interactive_schedule.go @@ -230,7 +230,7 @@ func (c *AddInteractiveConfig) selectScheduleFrequency() error { ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(c.Ctx); err != nil { return fmt.Errorf("failed to select schedule frequency: %w", err) } diff --git a/pkg/cli/add_interactive_workflow.go b/pkg/cli/add_interactive_workflow.go index a2322601b49..272dabab501 100644 --- a/pkg/cli/add_interactive_workflow.go +++ b/pkg/cli/add_interactive_workflow.go @@ -112,7 +112,7 @@ func (c *AddInteractiveConfig) checkStatusAndOfferRun(ctx context.Context) error ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(ctx); err != nil { return nil // Not critical, just skip } diff --git a/pkg/cli/engine_secrets.go b/pkg/cli/engine_secrets.go index f2989b331b2..8a6651d0513 100644 --- a/pkg/cli/engine_secrets.go +++ b/pkg/cli/engine_secrets.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -34,6 +35,8 @@ type SecretRequirement struct { // EngineSecretConfig contains configuration for engine secret collection operations type EngineSecretConfig struct { + // Ctx is the context for cancellation (optional, but recommended for proper Ctrl-C handling) + Ctx context.Context // RepoSlug is the repository slug to check for existing secrets (optional) RepoSlug string // Engine is the engine type to collect secrets for (e.g., "copilot", "claude", "codex") @@ -171,6 +174,14 @@ func getMissingRequiredSecrets(requirements []SecretRequirement, existingSecrets return missing } +// ctx returns the context from the config, defaulting to background if nil +func (c EngineSecretConfig) ctx() context.Context { + if c.Ctx != nil { + return c.Ctx + } + return context.Background() +} + // checkAndEnsureEngineSecretsForEngine is the unified entry point for checking and collecting engine secrets. // It checks existing secrets in the repository and environment, and prompts for missing ones. func checkAndEnsureEngineSecretsForEngine(config EngineSecretConfig) error { @@ -310,7 +321,7 @@ func promptForCopilotPATUnified(req SecretRequirement, config EngineSecretConfig ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(config.ctx()); err != nil { return fmt.Errorf("failed to get Copilot token: %w", err) } @@ -358,7 +369,7 @@ func promptForSystemTokenUnified(req SecretRequirement, config EngineSecretConfi ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(config.ctx()); err != nil { return fmt.Errorf("failed to get %s token: %w", req.Name, err) } @@ -411,7 +422,7 @@ func promptForGenericAPIKeyUnified(req SecretRequirement, config EngineSecretCon ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(config.ctx()); err != nil { return fmt.Errorf("failed to get %s API key: %w", label, err) } diff --git a/pkg/cli/interactive.go b/pkg/cli/interactive.go index ad40a7b5df4..4ea0fe6cfd9 100644 --- a/pkg/cli/interactive.go +++ b/pkg/cli/interactive.go @@ -36,6 +36,7 @@ var commonWorkflowNames = []string{ // InteractiveWorkflowBuilder collects user input to build an agentic workflow type InteractiveWorkflowBuilder struct { + ctx context.Context WorkflowName string Trigger string Engine string @@ -60,6 +61,7 @@ func CreateWorkflowInteractively(ctx context.Context, workflowName string, verbo } builder := &InteractiveWorkflowBuilder{ + ctx: ctx, WorkflowName: workflowName, } @@ -101,7 +103,7 @@ func (b *InteractiveWorkflowBuilder) promptForWorkflowName() error { ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - return form.Run() + return form.RunWithContext(b.ctx) } // promptForConfiguration organizes all prompts into logical groups with titles and descriptions @@ -225,7 +227,7 @@ func (b *InteractiveWorkflowBuilder) promptForConfiguration() error { Description("Describe what you want this workflow to accomplish"), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(b.ctx); err != nil { return err } diff --git a/pkg/cli/pr_automerge.go b/pkg/cli/pr_automerge.go index 8a0468a9dd0..4c4cf796d02 100644 --- a/pkg/cli/pr_automerge.go +++ b/pkg/cli/pr_automerge.go @@ -1,6 +1,7 @@ package cli import ( + "context" "encoding/json" "errors" "fmt" @@ -116,18 +117,19 @@ func AutoMergePullRequestsLegacy(repoSlug string, verbose bool) error { } // WaitForWorkflowCompletion waits for a workflow run to complete, with a specified timeout -func WaitForWorkflowCompletion(repoSlug, runID string, timeoutMinutes int, verbose bool) error { +func WaitForWorkflowCompletion(ctx context.Context, repoSlug, runID string, timeoutMinutes int, verbose bool) error { prAutomergeLog.Printf("Waiting for workflow completion: repo=%s, runID=%s, timeout=%d minutes", repoSlug, runID, timeoutMinutes) timeout := time.Duration(timeoutMinutes) * time.Minute return PollWithSignalHandling(PollOptions{ + Ctx: ctx, PollInterval: 10 * time.Second, Timeout: timeout, - PollFunc: func() (PollResult, error) { - // Check workflow status - output, err := workflow.RunGH("Checking workflow status...", "run", "view", runID, "--repo", repoSlug, "--json", "status,conclusion") - + PollFunc: func(ctx context.Context) (PollResult, error) { + // Check workflow status with context-aware GH execution. + // ctx is cancelled on Ctrl-C, which causes RunGHContext to abort the gh subprocess. + output, err := workflow.RunGHContext(ctx, "Checking workflow status...", "run", "view", runID, "--repo", repoSlug, "--json", "status,conclusion") if err != nil { return PollFailure, fmt.Errorf("failed to check workflow status: %w", err) } diff --git a/pkg/cli/pr_automerge_test.go b/pkg/cli/pr_automerge_test.go index 388c09b2d07..b5651c596e6 100644 --- a/pkg/cli/pr_automerge_test.go +++ b/pkg/cli/pr_automerge_test.go @@ -3,6 +3,8 @@ package cli import ( + "context" + "errors" "testing" ) @@ -16,10 +18,33 @@ func TestWaitForWorkflowCompletionUsesSignalHandling(t *testing.T) { // but we can verify that the timeout mechanism works, which confirms // it's using the polling helper - err := WaitForWorkflowCompletion("nonexistent/repo", "12345", 0, false) + err := WaitForWorkflowCompletion(context.Background(), "nonexistent/repo", "12345", 0, false) // Should timeout or fail to check workflow status if err == nil { t.Error("Expected error for nonexistent workflow, got nil") } } + +// TestWaitForWorkflowCompletion_ContextCancellation verifies that WaitForWorkflowCompletion +// propagates cancellation when the context is cancelled, so callers (e.g. the repeat loop) +// can detect an intentional interruption and stop immediately. +func TestWaitForWorkflowCompletion_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + // Cancel immediately so the poll loop exits on the first ctx.Done() check. + cancel() + + err := WaitForWorkflowCompletion(ctx, "nonexistent/repo", "12345", 5, false) + + if err == nil { + t.Fatal("Expected error on cancelled context, got nil") + } + + // Must be either ErrInterrupted (from the poll select loop) or context.Canceled + // (from the PollFunc guard when ctx is already cancelled before polling begins). + // Both indicate an intentional interruption that callers should detect and propagate. + isInterruption := errors.Is(err, ErrInterrupted) || errors.Is(err, context.Canceled) + if !isInterruption { + t.Errorf("Expected interruption error (ErrInterrupted or context.Canceled) from WaitForWorkflowCompletion, got: %v", err) + } +} diff --git a/pkg/cli/retry.go b/pkg/cli/retry.go index b0404bede77..c33ac192d3f 100644 --- a/pkg/cli/retry.go +++ b/pkg/cli/retry.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "os" "os/signal" @@ -16,6 +17,8 @@ var retryLog = logger.New("cli:retry") // RepeatOptions contains configuration for the repeat functionality type RepeatOptions struct { + // Context for cancellation (optional, but recommended for proper Ctrl-C handling) + Ctx context.Context // Number of times to repeat execution (0 = run once) RepeatCount int // Message to display when starting repeat mode @@ -60,25 +63,41 @@ func ExecuteWithRepeat(options RepeatOptions) error { } fmt.Fprintln(output, console.FormatInfoMessage(startMsg)) + // Use provided context or fall back to background context + ctx := options.Ctx + if ctx == nil { + ctx = context.Background() + } + // Set up signal handling for graceful shutdown + // Signal channel provides a fallback when no context is provided or for direct OS signals sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) defer signal.Stop(sigChan) + // runCleanup executes the optional cleanup function + runCleanup := func() { + if options.CleanupFunc != nil { + retryLog.Print("Executing cleanup function") + options.CleanupFunc() + } + } + // Run the specified number of additional times for i := 1; i <= options.RepeatCount; i++ { select { + case <-ctx.Done(): + retryLog.Printf("Context cancelled at iteration %d/%d", i, options.RepeatCount) + fmt.Fprintln(output, console.FormatInfoMessage("Received interrupt signal, stopping repeat...")) + runCleanup() + return ctx.Err() + case <-sigChan: retryLog.Printf("Interrupt signal received at iteration %d/%d", i, options.RepeatCount) fmt.Fprintln(output, console.FormatInfoMessage("Received interrupt signal, stopping repeat...")) + runCleanup() + return context.Canceled - // Execute cleanup function if provided - if options.CleanupFunc != nil { - retryLog.Print("Executing cleanup function") - options.CleanupFunc() - } - - return nil default: retryLog.Printf("Starting iteration %d/%d", i, options.RepeatCount) // Use provided repeat message or default diff --git a/pkg/cli/run_interactive.go b/pkg/cli/run_interactive.go index 59193cb87f8..c0131844807 100644 --- a/pkg/cli/run_interactive.go +++ b/pkg/cli/run_interactive.go @@ -52,7 +52,7 @@ func RunWorkflowInteractively(ctx context.Context, verbose bool, repoOverride st } // Step 2: Let user select a workflow - selectedWorkflow, err := selectWorkflow(workflows) + selectedWorkflow, err := selectWorkflow(ctx, workflows) if err != nil { return fmt.Errorf("workflow selection cancelled or failed: %w", err) } @@ -63,13 +63,13 @@ func RunWorkflowInteractively(ctx context.Context, verbose bool, repoOverride st showWorkflowInfo(selectedWorkflow) // Step 4: Collect workflow inputs if needed - inputValues, err := collectWorkflowInputs(selectedWorkflow) + inputValues, err := collectWorkflowInputs(ctx, selectedWorkflow) if err != nil { return fmt.Errorf("failed to collect workflow inputs: %w", err) } // Step 5: Confirm execution - if !confirmExecution(selectedWorkflow, inputValues) { + if !confirmExecution(ctx, selectedWorkflow, inputValues) { fmt.Fprintln(os.Stderr, console.FormatWarningMessage("Workflow execution cancelled")) return nil } @@ -168,7 +168,7 @@ func buildWorkflowDescription(inputs map[string]*workflow.InputDefinition) strin } // selectWorkflow displays an interactive list for workflow selection with fuzzy search -func selectWorkflow(workflows []WorkflowOption) (*WorkflowOption, error) { +func selectWorkflow(ctx context.Context, workflows []WorkflowOption) (*WorkflowOption, error) { runInteractiveLog.Printf("Displaying workflow selection: %d workflows", len(workflows)) // Check if we're in a TTY environment @@ -192,7 +192,7 @@ func selectWorkflow(workflows []WorkflowOption) (*WorkflowOption, error) { ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(ctx); err != nil { return nil, fmt.Errorf("workflow selection cancelled or failed: %w", err) } @@ -258,17 +258,17 @@ func showWorkflowInfo(wf *WorkflowOption) { } // collectWorkflowInputs collects input values from the user -func collectWorkflowInputs(wf *WorkflowOption) ([]string, error) { +func collectWorkflowInputs(ctx context.Context, wf *WorkflowOption) ([]string, error) { if len(wf.Inputs) == 0 { return nil, nil } runInteractiveLog.Printf("Collecting %d workflow inputs", len(wf.Inputs)) - return collectInputsWithMap(wf.Inputs) + return collectInputsWithMap(ctx, wf.Inputs) } // collectInputsWithMap collects inputs using a map to properly capture values -func collectInputsWithMap(inputs map[string]*workflow.InputDefinition) ([]string, error) { +func collectInputsWithMap(ctx context.Context, inputs map[string]*workflow.InputDefinition) ([]string, error) { // Create a map to store string values for the form inputValues := make(map[string]string) // Create a map to track the string pointers we'll pass to huh @@ -315,7 +315,7 @@ func collectInputsWithMap(inputs map[string]*workflow.InputDefinition) ([]string // Show the form form := huh.NewForm(formGroups...).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(ctx); err != nil { return nil, fmt.Errorf("input collection cancelled: %w", err) } @@ -333,7 +333,7 @@ func collectInputsWithMap(inputs map[string]*workflow.InputDefinition) ([]string } // confirmExecution asks the user to confirm workflow execution -func confirmExecution(wf *WorkflowOption, inputs []string) bool { +func confirmExecution(ctx context.Context, wf *WorkflowOption, inputs []string) bool { runInteractiveLog.Print("Requesting execution confirmation") var confirm bool @@ -353,7 +353,7 @@ func confirmExecution(wf *WorkflowOption, inputs []string) bool { ), ).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode()) - if err := form.Run(); err != nil { + if err := form.RunWithContext(ctx); err != nil { runInteractiveLog.Printf("Confirmation failed: %v", err) return false } @@ -399,13 +399,13 @@ func RunSpecificWorkflowInteractively(ctx context.Context, workflowName string, } // Collect workflow inputs if needed - inputValues, err := collectWorkflowInputs(wf) + inputValues, err := collectWorkflowInputs(ctx, wf) if err != nil { return fmt.Errorf("failed to collect workflow inputs: %w", err) } // Confirm execution (skip if no inputs were collected - user already confirmed they want to run) - if len(inputValues) > 0 && !confirmExecution(wf, inputValues) { + if len(inputValues) > 0 && !confirmExecution(ctx, wf, inputValues) { fmt.Fprintln(os.Stderr, console.FormatWarningMessage("Workflow execution cancelled")) return nil } diff --git a/pkg/cli/run_workflow_execution.go b/pkg/cli/run_workflow_execution.go index 36329e51550..0c2fbc75758 100644 --- a/pkg/cli/run_workflow_execution.go +++ b/pkg/cli/run_workflow_execution.go @@ -353,7 +353,7 @@ func RunWorkflowOnGitHub(ctx context.Context, workflowIdOrName string, opts RunO } // Execute gh workflow run command and capture output - cmd := workflow.ExecGH(args...) + cmd := workflow.ExecGHContext(ctx, args...) if opts.Verbose { var cmdParts []string @@ -460,7 +460,11 @@ func RunWorkflowOnGitHub(ctx context.Context, workflowIdOrName string, opts RunO } runIDStr := strconv.FormatInt(runInfo.DatabaseID, 10) - if err := WaitForWorkflowCompletion(targetRepo, runIDStr, 30, opts.Verbose); err != nil { + if err := WaitForWorkflowCompletion(ctx, targetRepo, runIDStr, 30, opts.Verbose); err != nil { + // Propagate interrupts/cancellation so the caller (repeat loop) can stop + if ctx.Err() != nil || errors.Is(err, ErrInterrupted) { + return err + } if opts.AutoMergePRs { fmt.Fprintln(os.Stderr, console.FormatWarningMessage(fmt.Sprintf("Workflow did not complete successfully, skipping auto-merge: %v", err))) } else { @@ -610,6 +614,7 @@ func RunWorkflowsOnGitHub(ctx context.Context, workflowNames []string, opts RunO // Execute workflows with optional repeat functionality return ExecuteWithRepeat(RepeatOptions{ + Ctx: ctx, RepeatCount: opts.RepeatCount, RepeatMessage: "Repeating workflow run", ExecuteFunc: runAllWorkflows, diff --git a/pkg/cli/signal_aware_poll.go b/pkg/cli/signal_aware_poll.go index 101a499f009..7c4b5ebf714 100644 --- a/pkg/cli/signal_aware_poll.go +++ b/pkg/cli/signal_aware_poll.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -14,6 +15,9 @@ import ( var pollLog = logger.New("cli:signal_aware_poll") +// ErrInterrupted is returned when polling is interrupted by a signal or context cancellation +var ErrInterrupted = errors.New("interrupted by user") + // PollResult represents the result of a polling operation type PollResult int @@ -28,13 +32,17 @@ const ( // PollOptions contains configuration for signal-aware polling type PollOptions struct { + // Context for cancellation (optional, but recommended for proper Ctrl-C handling) + Ctx context.Context // Interval between poll attempts PollInterval time.Duration // Timeout for the entire polling operation Timeout time.Duration - // Function to call on each poll iteration - // Should return PollContinue to keep polling, PollSuccess to succeed, or PollFailure to fail - PollFunc func() (PollResult, error) + // Function to call on each poll iteration. + // The ctx passed to PollFunc is the same context used by the poll loop, so callers can + // pass it to context-aware operations (e.g. RunGHContext) to abort mid-call on Ctrl-C. + // Should return PollContinue to keep polling, PollSuccess to succeed, or PollFailure to fail. + PollFunc func(ctx context.Context) (PollResult, error) // Message to display when polling starts (optional) StartMessage string // Message to display on each poll iteration (optional) @@ -54,7 +62,14 @@ func PollWithSignalHandling(options PollOptions) error { fmt.Fprintln(os.Stderr, console.FormatInfoMessage(options.StartMessage)) } + // Use provided context or fall back to background context + ctx := options.Ctx + if ctx == nil { + ctx = context.Background() + } + // Set up signal handling for graceful shutdown + // Signal channel provides a fallback when no context is provided or for direct OS signals sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) defer signal.Stop(sigChan) @@ -65,7 +80,7 @@ func PollWithSignalHandling(options PollOptions) error { defer ticker.Stop() // Perform initial check immediately - result, err := options.PollFunc() + result, err := options.PollFunc(ctx) switch result { case PollSuccess: if options.Verbose && options.SuccessMessage != "" { @@ -79,10 +94,19 @@ func PollWithSignalHandling(options PollOptions) error { // Continue polling for { select { + case <-ctx.Done(): + pollLog.Printf("Context cancelled (%v), stopping poll", ctx.Err()) + msg := "Operation cancelled, stopping wait..." + if err := ctx.Err(); err != nil { + msg = fmt.Sprintf("Operation cancelled (%v), stopping wait...", err) + } + fmt.Fprintln(os.Stderr, console.FormatInfoMessage(msg)) + return ErrInterrupted + case <-sigChan: pollLog.Print("Received interrupt signal") fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Received interrupt signal, stopping wait...")) - return errors.New("interrupted by user") + return ErrInterrupted case <-ticker.C: // Check if timeout exceeded @@ -92,7 +116,7 @@ func PollWithSignalHandling(options PollOptions) error { } // Poll for status - result, err := options.PollFunc() + result, err := options.PollFunc(ctx) switch result { case PollSuccess: diff --git a/pkg/cli/signal_aware_poll_test.go b/pkg/cli/signal_aware_poll_test.go index 44dac376f3b..da9c4713a1b 100644 --- a/pkg/cli/signal_aware_poll_test.go +++ b/pkg/cli/signal_aware_poll_test.go @@ -3,6 +3,7 @@ package cli import ( + "context" "errors" "testing" "time" @@ -13,7 +14,7 @@ func TestPollWithSignalHandling_Success(t *testing.T) { err := PollWithSignalHandling(PollOptions{ PollInterval: 10 * time.Millisecond, Timeout: 1 * time.Second, - PollFunc: func() (PollResult, error) { + PollFunc: func(_ context.Context) (PollResult, error) { callCount++ if callCount >= 3 { return PollSuccess, nil @@ -37,7 +38,7 @@ func TestPollWithSignalHandling_Failure(t *testing.T) { err := PollWithSignalHandling(PollOptions{ PollInterval: 10 * time.Millisecond, Timeout: 1 * time.Second, - PollFunc: func() (PollResult, error) { + PollFunc: func(_ context.Context) (PollResult, error) { return PollFailure, expectedErr }, Verbose: false, @@ -56,7 +57,7 @@ func TestPollWithSignalHandling_Timeout(t *testing.T) { err := PollWithSignalHandling(PollOptions{ PollInterval: 50 * time.Millisecond, Timeout: 100 * time.Millisecond, - PollFunc: func() (PollResult, error) { + PollFunc: func(_ context.Context) (PollResult, error) { return PollContinue, nil }, Verbose: false, @@ -76,7 +77,7 @@ func TestPollWithSignalHandling_ImmediateSuccess(t *testing.T) { err := PollWithSignalHandling(PollOptions{ PollInterval: 10 * time.Millisecond, Timeout: 1 * time.Second, - PollFunc: func() (PollResult, error) { + PollFunc: func(_ context.Context) (PollResult, error) { callCount++ return PollSuccess, nil }, @@ -104,3 +105,60 @@ func TestPollWithSignalHandling_SignalInterruption(t *testing.T) { // This test just verifies the structure is correct t.Skip("Signal interruption requires manual testing - implementation verified by code review") } + +// TestPollWithSignalHandling_ContextCancellation verifies that PollWithSignalHandling +// returns ErrInterrupted when the context is cancelled, enabling proper Ctrl-C propagation. +func TestPollWithSignalHandling_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + pollStarted := make(chan struct{}) + err := func() error { + // Cancel context after poll loop starts its first wait + go func() { + <-pollStarted + cancel() + }() + return PollWithSignalHandling(PollOptions{ + Ctx: ctx, + PollInterval: 50 * time.Millisecond, + Timeout: 5 * time.Second, + PollFunc: func(_ context.Context) (PollResult, error) { + // Signal that the poll loop is running, then keep returning Continue + select { + case <-pollStarted: + default: + close(pollStarted) + } + return PollContinue, nil + }, + Verbose: false, + }) + }() + + if !errors.Is(err, ErrInterrupted) { + t.Errorf("Expected ErrInterrupted on context cancellation, got: %v", err) + } +} + +// TestPollWithSignalHandling_AlreadyCancelledContext verifies that PollWithSignalHandling +// returns ErrInterrupted immediately when given an already-cancelled context. +func TestPollWithSignalHandling_AlreadyCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel before starting + + // The initial PollFunc call might succeed or return Continue, but the + // next select iteration should detect ctx.Done() and return ErrInterrupted. + err := PollWithSignalHandling(PollOptions{ + Ctx: ctx, + PollInterval: 10 * time.Millisecond, + Timeout: 5 * time.Second, + PollFunc: func(_ context.Context) (PollResult, error) { + return PollContinue, nil + }, + Verbose: false, + }) + + if !errors.Is(err, ErrInterrupted) { + t.Errorf("Expected ErrInterrupted for already-cancelled context, got: %v", err) + } +} diff --git a/pkg/cli/trial_command.go b/pkg/cli/trial_command.go index abe2f67cc27..f0c8fd3c0d9 100644 --- a/pkg/cli/trial_command.go +++ b/pkg/cli/trial_command.go @@ -335,6 +335,7 @@ func RunWorkflowTrials(ctx context.Context, workflowSpecs []string, opts TrialOp // Ensure the required engine secret is available (prompts interactively if needed) secretConfig := EngineSecretConfig{ + Ctx: ctx, RepoSlug: hostRepoSlug, Engine: opts.EngineOverride, Verbose: opts.Verbose, @@ -473,7 +474,14 @@ func RunWorkflowTrials(ctx context.Context, workflowSpecs []string, opts TrialOp fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Workflow run started with ID: %s (%s)", runID, workflowRunURL))) // Wait for workflow completion - if err := WaitForWorkflowCompletion(hostRepoSlug, runID, opts.TimeoutMinutes, opts.Verbose); err != nil { + if err := WaitForWorkflowCompletion(ctx, hostRepoSlug, runID, opts.TimeoutMinutes, opts.Verbose); err != nil { + // If the context was canceled or its deadline was exceeded, return that directly + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return err + } return fmt.Errorf("workflow '%s' execution failed or timed out: %w", parsedSpec.WorkflowName, err) }