Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cmd/gh-aw/main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package main

import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"sort"
"strings"
"syscall"

"github.com/github/gh-aw/pkg/cli"
"github.com/github/gh-aw/pkg/console"
Expand Down Expand Up @@ -829,7 +832,12 @@ func main() {
// Set release flag in the workflow package
workflow.SetIsRelease(isRelease == "true")

if err := rootCmd.Execute(); err != nil {
// Set up a context that is cancelled when Ctrl-C (SIGINT) or SIGTERM is received.
// This ensures all commands and subprocesses are properly interrupted on Ctrl-C.
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

if err := rootCmd.ExecuteContext(ctx); err != nil {
errMsg := err.Error()
// Check if error is already formatted to avoid double formatting:
// - Contains suggestions (FormatErrorWithSuggestions)
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/add_interactive_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (c *AddInteractiveConfig) checkGitRepository() error {
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(c.Ctx); err != nil {
return fmt.Errorf("failed to get repository info: %w", err)
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/add_interactive_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (c *AddInteractiveConfig) selectAIEngineAndKey() error {
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(c.Ctx); err != nil {
return fmt.Errorf("failed to select coding agent: %w", err)
}

Expand Down Expand Up @@ -167,6 +167,7 @@ func (c *AddInteractiveConfig) configureEngineAPISecret(engine string) error {

// Use the unified checkAndEnsureEngineSecrets function
config := EngineSecretConfig{
Ctx: c.Ctx,
RepoSlug: c.RepoOverride,
Engine: engine,
Verbose: c.Verbose,
Expand Down
4 changes: 3 additions & 1 deletion pkg/cli/add_interactive_orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var addInteractiveLog = logger.New("cli:add_interactive")

// AddInteractiveConfig holds configuration for interactive add mode
type AddInteractiveConfig struct {
Ctx context.Context // Context for cancellation (Ctrl-C handling)
WorkflowSpecs []string
Verbose bool
EngineOverride string
Expand Down Expand Up @@ -71,6 +72,7 @@ func RunAddInteractive(ctx context.Context, workflowSpecs []string, verbose bool
}

config := &AddInteractiveConfig{
Ctx: ctx,
WorkflowSpecs: workflowSpecs,
Verbose: verbose,
EngineOverride: engineOverride,
Expand Down Expand Up @@ -231,7 +233,7 @@ func (c *AddInteractiveConfig) confirmChanges(workflowFiles, initFiles []string,
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(c.Ctx); err != nil {
return fmt.Errorf("confirmation failed: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/add_interactive_schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (c *AddInteractiveConfig) selectScheduleFrequency() error {
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(c.Ctx); err != nil {
return fmt.Errorf("failed to select schedule frequency: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/add_interactive_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (c *AddInteractiveConfig) checkStatusAndOfferRun(ctx context.Context) error
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(ctx); err != nil {
return nil // Not critical, just skip
}

Expand Down
17 changes: 14 additions & 3 deletions pkg/cli/engine_secrets.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -34,6 +35,8 @@ type SecretRequirement struct {

// EngineSecretConfig contains configuration for engine secret collection operations
type EngineSecretConfig struct {
// Ctx is the context for cancellation (optional, but recommended for proper Ctrl-C handling)
Ctx context.Context
// RepoSlug is the repository slug to check for existing secrets (optional)
RepoSlug string
// Engine is the engine type to collect secrets for (e.g., "copilot", "claude", "codex")
Expand Down Expand Up @@ -171,6 +174,14 @@ func getMissingRequiredSecrets(requirements []SecretRequirement, existingSecrets
return missing
}

// ctx returns the context from the config, defaulting to background if nil
func (c EngineSecretConfig) ctx() context.Context {
if c.Ctx != nil {
return c.Ctx
}
return context.Background()
}

// checkAndEnsureEngineSecretsForEngine is the unified entry point for checking and collecting engine secrets.
// It checks existing secrets in the repository and environment, and prompts for missing ones.
func checkAndEnsureEngineSecretsForEngine(config EngineSecretConfig) error {
Expand Down Expand Up @@ -310,7 +321,7 @@ func promptForCopilotPATUnified(req SecretRequirement, config EngineSecretConfig
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(config.ctx()); err != nil {
return fmt.Errorf("failed to get Copilot token: %w", err)
}

Expand Down Expand Up @@ -358,7 +369,7 @@ func promptForSystemTokenUnified(req SecretRequirement, config EngineSecretConfi
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(config.ctx()); err != nil {
return fmt.Errorf("failed to get %s token: %w", req.Name, err)
}

Expand Down Expand Up @@ -411,7 +422,7 @@ func promptForGenericAPIKeyUnified(req SecretRequirement, config EngineSecretCon
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(config.ctx()); err != nil {
return fmt.Errorf("failed to get %s API key: %w", label, err)
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/cli/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var commonWorkflowNames = []string{

// InteractiveWorkflowBuilder collects user input to build an agentic workflow
type InteractiveWorkflowBuilder struct {
ctx context.Context
WorkflowName string
Trigger string
Engine string
Expand All @@ -60,6 +61,7 @@ func CreateWorkflowInteractively(ctx context.Context, workflowName string, verbo
}

builder := &InteractiveWorkflowBuilder{
ctx: ctx,
WorkflowName: workflowName,
}

Expand Down Expand Up @@ -101,7 +103,7 @@ func (b *InteractiveWorkflowBuilder) promptForWorkflowName() error {
),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

return form.Run()
return form.RunWithContext(b.ctx)
}

// promptForConfiguration organizes all prompts into logical groups with titles and descriptions
Expand Down Expand Up @@ -225,7 +227,7 @@ func (b *InteractiveWorkflowBuilder) promptForConfiguration() error {
Description("Describe what you want this workflow to accomplish"),
).WithTheme(styles.HuhTheme()).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
if err := form.RunWithContext(b.ctx); err != nil {
return err
}

Expand Down
12 changes: 7 additions & 5 deletions pkg/cli/pr_automerge.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -116,18 +117,19 @@ func AutoMergePullRequestsLegacy(repoSlug string, verbose bool) error {
}

// WaitForWorkflowCompletion waits for a workflow run to complete, with a specified timeout
func WaitForWorkflowCompletion(repoSlug, runID string, timeoutMinutes int, verbose bool) error {
func WaitForWorkflowCompletion(ctx context.Context, repoSlug, runID string, timeoutMinutes int, verbose bool) error {
prAutomergeLog.Printf("Waiting for workflow completion: repo=%s, runID=%s, timeout=%d minutes", repoSlug, runID, timeoutMinutes)

timeout := time.Duration(timeoutMinutes) * time.Minute

return PollWithSignalHandling(PollOptions{
Ctx: ctx,
PollInterval: 10 * time.Second,
Timeout: timeout,
PollFunc: func() (PollResult, error) {
// Check workflow status
output, err := workflow.RunGH("Checking workflow status...", "run", "view", runID, "--repo", repoSlug, "--json", "status,conclusion")

PollFunc: func(ctx context.Context) (PollResult, error) {
// Check workflow status with context-aware GH execution.
// ctx is cancelled on Ctrl-C, which causes RunGHContext to abort the gh subprocess.
output, err := workflow.RunGHContext(ctx, "Checking workflow status...", "run", "view", runID, "--repo", repoSlug, "--json", "status,conclusion")
if err != nil {
return PollFailure, fmt.Errorf("failed to check workflow status: %w", err)
}
Expand Down
27 changes: 26 additions & 1 deletion pkg/cli/pr_automerge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package cli

import (
"context"
"errors"
"testing"
)

Expand All @@ -16,10 +18,33 @@ func TestWaitForWorkflowCompletionUsesSignalHandling(t *testing.T) {
// but we can verify that the timeout mechanism works, which confirms
// it's using the polling helper

err := WaitForWorkflowCompletion("nonexistent/repo", "12345", 0, false)
err := WaitForWorkflowCompletion(context.Background(), "nonexistent/repo", "12345", 0, false)

// Should timeout or fail to check workflow status
if err == nil {
t.Error("Expected error for nonexistent workflow, got nil")
}
}

// TestWaitForWorkflowCompletion_ContextCancellation verifies that WaitForWorkflowCompletion
// propagates cancellation when the context is cancelled, so callers (e.g. the repeat loop)
// can detect an intentional interruption and stop immediately.
func TestWaitForWorkflowCompletion_ContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// Cancel immediately so the poll loop exits on the first ctx.Done() check.
cancel()

err := WaitForWorkflowCompletion(ctx, "nonexistent/repo", "12345", 5, false)

if err == nil {
t.Fatal("Expected error on cancelled context, got nil")
}

// Must be either ErrInterrupted (from the poll select loop) or context.Canceled
// (from the PollFunc guard when ctx is already cancelled before polling begins).
// Both indicate an intentional interruption that callers should detect and propagate.
isInterruption := errors.Is(err, ErrInterrupted) || errors.Is(err, context.Canceled)
if !isInterruption {
t.Errorf("Expected interruption error (ErrInterrupted or context.Canceled) from WaitForWorkflowCompletion, got: %v", err)
}
}
33 changes: 26 additions & 7 deletions pkg/cli/retry.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"context"
"fmt"
"os"
"os/signal"
Expand All @@ -16,6 +17,8 @@ var retryLog = logger.New("cli:retry")

// RepeatOptions contains configuration for the repeat functionality
type RepeatOptions struct {
// Context for cancellation (optional, but recommended for proper Ctrl-C handling)
Ctx context.Context
// Number of times to repeat execution (0 = run once)
RepeatCount int
// Message to display when starting repeat mode
Expand Down Expand Up @@ -60,25 +63,41 @@ func ExecuteWithRepeat(options RepeatOptions) error {
}
fmt.Fprintln(output, console.FormatInfoMessage(startMsg))

// Use provided context or fall back to background context
ctx := options.Ctx
if ctx == nil {
ctx = context.Background()
}

// Set up signal handling for graceful shutdown
// Signal channel provides a fallback when no context is provided or for direct OS signals
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(sigChan)

// runCleanup executes the optional cleanup function
runCleanup := func() {
if options.CleanupFunc != nil {
retryLog.Print("Executing cleanup function")
options.CleanupFunc()
}
}

// Run the specified number of additional times
for i := 1; i <= options.RepeatCount; i++ {
select {
case <-ctx.Done():
retryLog.Printf("Context cancelled at iteration %d/%d", i, options.RepeatCount)
fmt.Fprintln(output, console.FormatInfoMessage("Received interrupt signal, stopping repeat..."))
runCleanup()
return ctx.Err()

case <-sigChan:
retryLog.Printf("Interrupt signal received at iteration %d/%d", i, options.RepeatCount)
fmt.Fprintln(output, console.FormatInfoMessage("Received interrupt signal, stopping repeat..."))
runCleanup()
return context.Canceled

// Execute cleanup function if provided
if options.CleanupFunc != nil {
retryLog.Print("Executing cleanup function")
options.CleanupFunc()
}

return nil
default:
retryLog.Printf("Starting iteration %d/%d", i, options.RepeatCount)
// Use provided repeat message or default
Expand Down
Loading
Loading