diff --git a/cmd/new.go b/cmd/new.go index 374e079..640dcd8 100644 --- a/cmd/new.go +++ b/cmd/new.go @@ -36,16 +36,19 @@ will be used as the parent.`, parent = args[1] } - if err := runNew(branchName, parent); err != nil { + gitClient := git.NewGitClient() + githubClient := github.NewGitHubClient() + + if err := runNew(gitClient, githubClient, branchName, parent); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } -func runNew(branchName string, explicitParent string) error { +func runNew(gitClient git.GitClient, githubClient github.GitHubClient, branchName string, explicitParent string) error { // Check if branch already exists - if git.BranchExists(branchName) { + if gitClient.BranchExists(branchName) { return fmt.Errorf("branch %s already exists", branchName) } @@ -55,12 +58,12 @@ func runNew(branchName string, explicitParent string) error { // Use explicitly provided parent parent = explicitParent // Verify parent exists - if !git.BranchExists(parent) { + if !gitClient.BranchExists(parent) { return fmt.Errorf("parent branch %s does not exist", parent) } } else { // Get current branch as parent - currentBranch, err := git.GetCurrentBranch() + currentBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } @@ -68,10 +71,10 @@ func runNew(branchName string, explicitParent string) error { // If current branch has no parent, check if it's the base branch // Otherwise use it as parent parent = currentBranch - currentParent := git.GetConfig(fmt.Sprintf("branch.%s.stackparent", currentBranch)) + currentParent := gitClient.GetConfig(fmt.Sprintf("branch.%s.stackparent", currentBranch)) // If we're not on a stack branch, use the base branch as parent - if currentParent == "" && currentBranch != stack.GetBaseBranch() { + if currentParent == "" && currentBranch != stack.GetBaseBranch(gitClient) { // Check if current branch IS the base branch or if we should use base parent = currentBranch } @@ -80,13 +83,13 @@ func runNew(branchName string, explicitParent string) error { fmt.Printf("Creating new branch %s from %s\n", branchName, parent) // Create the new branch - if err := git.CreateBranch(branchName, parent); err != nil { + if err := gitClient.CreateBranch(branchName, parent); err != nil { return fmt.Errorf("failed to create branch: %w", err) } // Set parent in git config configKey := fmt.Sprintf("branch.%s.stackparent", branchName) - if err := git.SetConfig(configKey, parent); err != nil { + if err := gitClient.SetConfig(configKey, parent); err != nil { return fmt.Errorf("failed to set parent config: %w", err) } @@ -95,7 +98,7 @@ func runNew(branchName string, explicitParent string) error { fmt.Println() // Show the full stack - if err := showStack(); err != nil { + if err := showStack(gitClient, githubClient); err != nil { // Don't fail if we can't show the stack, just warn fmt.Fprintf(os.Stderr, "Warning: failed to display stack: %v\n", err) } @@ -105,19 +108,19 @@ func runNew(branchName string, explicitParent string) error { } // showStack displays the current stack structure -func showStack() error { - currentBranch, err := git.GetCurrentBranch() +func showStack(gitClient git.GitClient, githubClient github.GitHubClient) error { + currentBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } - tree, err := stack.BuildStackTreeForBranch(currentBranch) + tree, err := stack.BuildStackTreeForBranch(gitClient, currentBranch) if err != nil { return fmt.Errorf("failed to build stack tree: %w", err) } // Fetch all PRs upfront for better performance - prCache, err := github.GetAllPRs() + prCache, err := githubClient.GetAllPRs() if err != nil { // If fetching PRs fails, just continue without PR info prCache = make(map[string]*github.PRInfo) @@ -126,7 +129,7 @@ func showStack() error { // Filter out branches with merged PRs from the tree (but keep current branch) tree = filterMergedBranchesForNew(tree, prCache, currentBranch) - printStackTree(tree, "", true, currentBranch, prCache) + printStackTree(gitClient, tree, "", true, currentBranch, prCache) return nil } @@ -170,16 +173,16 @@ func filterMergedBranchesForNew(node *stack.TreeNode, prCache map[string]*github } // printStackTree is a simplified version of the status tree printer -func printStackTree(node *stack.TreeNode, prefix string, isLast bool, currentBranch string, prCache map[string]*github.PRInfo) { +func printStackTree(gitClient git.GitClient, node *stack.TreeNode, prefix string, isLast bool, currentBranch string, prCache map[string]*github.PRInfo) { if node == nil { return } // Flatten the tree into a vertical list - printStackTreeVertical(node, currentBranch, prCache, false) + printStackTreeVertical(gitClient, node, currentBranch, prCache, false) } -func printStackTreeVertical(node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo, isPipe bool) { +func printStackTreeVertical(gitClient git.GitClient, node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo, isPipe bool) { if node == nil { return } @@ -191,7 +194,7 @@ func printStackTreeVertical(node *stack.TreeNode, currentBranch string, prCache // Get PR info from cache prInfo := "" - if node.Name != stack.GetBaseBranch() { + if node.Name != stack.GetBaseBranch(gitClient) { if pr, exists := prCache[node.Name]; exists { prInfo = fmt.Sprintf(" [%s :%s]", pr.URL, strings.ToLower(pr.State)) } @@ -207,6 +210,6 @@ func printStackTreeVertical(node *stack.TreeNode, currentBranch string, prCache // Print children vertically for _, child := range node.Children { - printStackTreeVertical(child, currentBranch, prCache, true) + printStackTreeVertical(gitClient, child, currentBranch, prCache, true) } } diff --git a/cmd/new_test.go b/cmd/new_test.go new file mode 100644 index 0000000..634849a --- /dev/null +++ b/cmd/new_test.go @@ -0,0 +1,231 @@ +package cmd + +import ( + "fmt" + "testing" + + "github.com/javoire/stackinator/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestRunNew(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + branchName string + explicitParent string + setupMocks func(*testutil.MockGitClient, *testutil.MockGitHubClient) + expectError bool + }{ + { + name: "create branch with explicit parent", + branchName: "feature-b", + explicitParent: "feature-a", + setupMocks: func(mockGit *testutil.MockGitClient, mockGH *testutil.MockGitHubClient) { + // Branch doesn't exist + mockGit.On("BranchExists", "feature-b").Return(false) + // Parent exists + mockGit.On("BranchExists", "feature-a").Return(true) + // Create branch + mockGit.On("CreateBranch", "feature-b", "feature-a").Return(nil) + // Set config + mockGit.On("SetConfig", "branch.feature-b.stackparent", "feature-a").Return(nil) + }, + expectError: false, + }, + { + name: "create branch from current", + branchName: "feature-b", + explicitParent: "", + setupMocks: func(mockGit *testutil.MockGitClient, mockGH *testutil.MockGitHubClient) { + // Branch doesn't exist + mockGit.On("BranchExists", "feature-b").Return(false) + // Get current branch + mockGit.On("GetCurrentBranch").Return("feature-a", nil) + // Check if current branch has parent + mockGit.On("GetConfig", "branch.feature-a.stackparent").Return("main") + // Create branch from current + mockGit.On("CreateBranch", "feature-b", "feature-a").Return(nil) + // Set config + mockGit.On("SetConfig", "branch.feature-b.stackparent", "feature-a").Return(nil) + }, + expectError: false, + }, + { + name: "error when branch exists", + branchName: "feature-a", + explicitParent: "main", + setupMocks: func(mockGit *testutil.MockGitClient, mockGH *testutil.MockGitHubClient) { + // Branch already exists + mockGit.On("BranchExists", "feature-a").Return(true) + }, + expectError: true, + }, + { + name: "error when parent doesn't exist", + branchName: "feature-b", + explicitParent: "non-existent", + setupMocks: func(mockGit *testutil.MockGitClient, mockGH *testutil.MockGitHubClient) { + // Branch doesn't exist + mockGit.On("BranchExists", "feature-b").Return(false) + // Parent doesn't exist + mockGit.On("BranchExists", "non-existent").Return(false) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + tt.setupMocks(mockGit, mockGH) + + // Set dryRun to true to skip the display logic at the end + dryRun = true + + err := runNew(mockGit, mockGH, tt.branchName, tt.explicitParent) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) + + // Reset dryRun + dryRun = false + }) + } +} + +func TestRunNewValidation(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + t.Run("validates branch name", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Branch already exists + mockGit.On("BranchExists", "existing-branch").Return(true) + + err := runNew(mockGit, mockGH, "existing-branch", "main") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "already exists") + + mockGit.AssertExpectations(t) + }) + + t.Run("validates parent exists", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Branch doesn't exist + mockGit.On("BranchExists", "new-branch").Return(false) + // Parent doesn't exist + mockGit.On("BranchExists", "non-existent-parent").Return(false) + + err := runNew(mockGit, mockGH, "new-branch", "non-existent-parent") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "does not exist") + + mockGit.AssertExpectations(t) + }) +} + +func TestRunNewSetConfig(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Branch doesn't exist + mockGit.On("BranchExists", "new-branch").Return(false) + // Parent exists + mockGit.On("BranchExists", "parent-branch").Return(true) + // Create branch + mockGit.On("CreateBranch", "new-branch", "parent-branch").Return(nil) + // Verify SetConfig is called with correct parameters + mockGit.On("SetConfig", "branch.new-branch.stackparent", "parent-branch").Return(nil) + + dryRun = true + err := runNew(mockGit, mockGH, "new-branch", "parent-branch") + dryRun = false + + assert.NoError(t, err) + mockGit.AssertExpectations(t) +} + +func TestRunNewFromCurrentBranch(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Branch doesn't exist + mockGit.On("BranchExists", "new-branch").Return(false) + // Get current branch + mockGit.On("GetCurrentBranch").Return("current-branch", nil) + // Current branch has a parent (it's in a stack) + mockGit.On("GetConfig", "branch.current-branch.stackparent").Return("main") + // Create branch from current + mockGit.On("CreateBranch", "new-branch", "current-branch").Return(nil) + // Set config + mockGit.On("SetConfig", "branch.new-branch.stackparent", "current-branch").Return(nil) + + dryRun = true + err := runNew(mockGit, mockGH, "new-branch", "") + dryRun = false + + assert.NoError(t, err) + mockGit.AssertExpectations(t) +} + +func TestRunNewErrorHandling(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + t.Run("error on CreateBranch failure", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + mockGit.On("BranchExists", "new-branch").Return(false) + mockGit.On("BranchExists", "parent").Return(true) + mockGit.On("CreateBranch", "new-branch", "parent").Return(fmt.Errorf("git error")) + + err := runNew(mockGit, mockGH, "new-branch", "parent") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create branch") + + mockGit.AssertExpectations(t) + }) + + t.Run("error on SetConfig failure", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + mockGit.On("BranchExists", "new-branch").Return(false) + mockGit.On("BranchExists", "parent").Return(true) + mockGit.On("CreateBranch", "new-branch", "parent").Return(nil) + mockGit.On("SetConfig", "branch.new-branch.stackparent", "parent").Return(fmt.Errorf("config error")) + + err := runNew(mockGit, mockGH, "new-branch", "parent") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to set parent config") + + mockGit.AssertExpectations(t) + }) +} + diff --git a/cmd/parent.go b/cmd/parent.go index 410d59c..e683fbf 100644 --- a/cmd/parent.go +++ b/cmd/parent.go @@ -18,22 +18,24 @@ is not part of a stack.`, Example: ` # Show parent of current branch stack parent`, Run: func(cmd *cobra.Command, args []string) { - if err := runParent(); err != nil { + gitClient := git.NewGitClient() + + if err := runParent(gitClient); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } -func runParent() error { +func runParent(gitClient git.GitClient) error { // Get current branch - currentBranch, err := git.GetCurrentBranch() + currentBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } // Get parent from git config - parent := git.GetConfig(fmt.Sprintf("branch.%s.stackparent", currentBranch)) + parent := gitClient.GetConfig(fmt.Sprintf("branch.%s.stackparent", currentBranch)) if parent == "" { fmt.Printf("%s (not in a stack)\n", currentBranch) diff --git a/cmd/prune.go b/cmd/prune.go index 2a1da42..01785fe 100644 --- a/cmd/prune.go +++ b/cmd/prune.go @@ -43,7 +43,10 @@ If a branch has unmerged commits locally, use --force to delete it anyway.`, # Preview what would be deleted stack prune --dry-run`, Run: func(cmd *cobra.Command, args []string) { - if err := runPrune(); err != nil { + gitClient := git.NewGitClient() + githubClient := github.NewGitHubClient() + + if err := runPrune(gitClient, githubClient); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } @@ -55,15 +58,15 @@ func init() { pruneCmd.Flags().BoolVarP(&pruneAll, "all", "a", false, "Check all local branches, not just stack branches") } -func runPrune() error { +func runPrune(gitClient git.GitClient, githubClient github.GitHubClient) error { // Get current branch so we don't delete it - currentBranch, err := git.GetCurrentBranch() + currentBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } // Get base branch to exclude it from pruning - baseBranch := stack.GetBaseBranch() + baseBranch := stack.GetBaseBranch(gitClient) // Start PR fetch in parallel with branch loading (PR fetch is the slowest operation) var wg sync.WaitGroup @@ -73,7 +76,7 @@ func runPrune() error { wg.Add(1) go func() { defer wg.Done() - prCache, prErr = github.GetAllPRs() + prCache, prErr = githubClient.GetAllPRs() }() // Get branches to check (runs in parallel with PR fetch) @@ -81,7 +84,7 @@ func runPrune() error { var branchErr error if pruneAll { // Check all local branches - branchNames, branchErr = git.ListBranches() + branchNames, branchErr = gitClient.ListBranches() if branchErr != nil { wg.Wait() // Wait for PR fetch before returning return fmt.Errorf("failed to get branches: %w", branchErr) @@ -98,7 +101,7 @@ func runPrune() error { } else { // Check only stack branches var stackBranches []stack.StackBranch - stackBranches, branchErr = stack.GetStackBranches() + stackBranches, branchErr = stack.GetStackBranches(gitClient) if branchErr != nil { wg.Wait() // Wait for PR fetch before returning return fmt.Errorf("failed to get stack branches: %w", branchErr) @@ -165,9 +168,9 @@ func runPrune() error { // Remove from stack tracking (if in stack) configKey := fmt.Sprintf("branch.%s.stackparent", branch) - if git.GetConfig(configKey) != "" { + if gitClient.GetConfig(configKey) != "" { fmt.Println(" Removing from stack tracking...") - if err := git.UnsetConfig(configKey); err != nil { + if err := gitClient.UnsetConfig(configKey); err != nil { fmt.Fprintf(os.Stderr, " Warning: failed to remove stack config: %v\n", err) } } @@ -183,9 +186,9 @@ func runPrune() error { fmt.Println(" Deleting branch...") var deleteErr error if pruneForce { - deleteErr = deleteBranchForce(branch) + deleteErr = deleteBranchForce(gitClient, branch) } else { - deleteErr = deleteBranch(branch) + deleteErr = deleteBranch(gitClient, branch) } if deleteErr != nil { @@ -205,17 +208,17 @@ func runPrune() error { } // deleteBranch deletes a branch using 'git branch -d' (safe delete) -func deleteBranch(name string) error { +func deleteBranch(gitClient git.GitClient, name string) error { if verbose { fmt.Printf(" [git] branch -d %s\n", name) } - return git.DeleteBranch(name) + return gitClient.DeleteBranch(name) } // deleteBranchForce deletes a branch using 'git branch -D' (force delete) -func deleteBranchForce(name string) error { +func deleteBranchForce(gitClient git.GitClient, name string) error { if verbose { fmt.Printf(" [git] branch -D %s\n", name) } - return git.DeleteBranchForce(name) + return gitClient.DeleteBranchForce(name) } diff --git a/cmd/rename.go b/cmd/rename.go index 677cc59..abd1df8 100644 --- a/cmd/rename.go +++ b/cmd/rename.go @@ -5,6 +5,7 @@ import ( "os" "github.com/javoire/stackinator/internal/git" + "github.com/javoire/stackinator/internal/github" "github.com/javoire/stackinator/internal/stack" "github.com/spf13/cobra" ) @@ -29,33 +30,36 @@ The command must be run while on the branch you want to rename.`, Run: func(cmd *cobra.Command, args []string) { newName := args[0] - if err := runRename(newName); err != nil { + gitClient := git.NewGitClient() + githubClient := github.NewGitHubClient() + + if err := runRename(gitClient, githubClient, newName); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } -func runRename(newName string) error { +func runRename(gitClient git.GitClient, githubClient github.GitHubClient, newName string) error { // Get current branch - oldName, err := git.GetCurrentBranch() + oldName, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } // Validate old branch is in the stack - oldParent := git.GetConfig(fmt.Sprintf("branch.%s.stackparent", oldName)) + oldParent := gitClient.GetConfig(fmt.Sprintf("branch.%s.stackparent", oldName)) if oldParent == "" { return fmt.Errorf("current branch %s is not part of a stack (no stackparent configured)", oldName) } // Check if new name already exists - if git.BranchExists(newName) { + if gitClient.BranchExists(newName) { return fmt.Errorf("branch %s already exists", newName) } // Get all children of the current branch - children, err := stack.GetChildrenOf(oldName) + children, err := stack.GetChildrenOf(gitClient, oldName) if err != nil { return fmt.Errorf("failed to get children: %w", err) } @@ -66,7 +70,7 @@ func runRename(newName string) error { } // Rename the branch - if err := git.RenameBranch(oldName, newName); err != nil { + if err := gitClient.RenameBranch(oldName, newName); err != nil { return fmt.Errorf("failed to rename branch: %w", err) } @@ -74,11 +78,11 @@ func runRename(newName string) error { oldConfigKey := fmt.Sprintf("branch.%s.stackparent", oldName) newConfigKey := fmt.Sprintf("branch.%s.stackparent", newName) - if err := git.SetConfig(newConfigKey, oldParent); err != nil { + if err := gitClient.SetConfig(newConfigKey, oldParent); err != nil { return fmt.Errorf("failed to set new parent config: %w", err) } - if err := git.UnsetConfig(oldConfigKey); err != nil { + if err := gitClient.UnsetConfig(oldConfigKey); err != nil { // This might fail if the branch was just renamed and git already handled it // Don't fail the whole operation if verbose { @@ -89,7 +93,7 @@ func runRename(newName string) error { // Update all children to point to the new name for _, child := range children { childConfigKey := fmt.Sprintf("branch.%s.stackparent", child.Name) - if err := git.SetConfig(childConfigKey, newName); err != nil { + if err := gitClient.SetConfig(childConfigKey, newName); err != nil { return fmt.Errorf("failed to update child %s: %w", child.Name, err) } fmt.Printf(" ✓ Updated child %s to point to %s\n", child.Name, newName) @@ -100,7 +104,7 @@ func runRename(newName string) error { fmt.Println() // Show the updated stack - if err := showStack(); err != nil { + if err := showStack(gitClient, githubClient); err != nil { // Don't fail if we can't show the stack, just warn fmt.Fprintf(os.Stderr, "Warning: failed to display stack: %v\n", err) } diff --git a/cmd/reparent.go b/cmd/reparent.go index 9859a54..ecd5a34 100644 --- a/cmd/reparent.go +++ b/cmd/reparent.go @@ -32,22 +32,25 @@ a feature is based on.`, Run: func(cmd *cobra.Command, args []string) { newParent := args[0] - if err := runReparent(newParent); err != nil { + gitClient := git.NewGitClient() + githubClient := github.NewGitHubClient() + + if err := runReparent(gitClient, githubClient, newParent); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } }, } -func runReparent(newParent string) error { +func runReparent(gitClient git.GitClient, githubClient github.GitHubClient, newParent string) error { // Get current branch - currentBranch, err := git.GetCurrentBranch() + currentBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } // Check if current branch is a stack branch - currentParent := git.GetConfig(fmt.Sprintf("branch.%s.stackparent", currentBranch)) + currentParent := gitClient.GetConfig(fmt.Sprintf("branch.%s.stackparent", currentBranch)) if currentParent == "" { return fmt.Errorf("branch %s is not part of a stack (no parent set)", currentBranch) } @@ -59,7 +62,7 @@ func runReparent(newParent string) error { } // Verify new parent branch exists - if !git.BranchExists(newParent) { + if !gitClient.BranchExists(newParent) { return fmt.Errorf("new parent branch %s does not exist", newParent) } @@ -69,7 +72,7 @@ func runReparent(newParent string) error { } // Check if new parent is a descendant of current branch (would create cycle) - if isDescendant(currentBranch, newParent) { + if isDescendant(gitClient, currentBranch, newParent) { return fmt.Errorf("cannot reparent to %s: it is a descendant of %s (would create a cycle)", newParent, currentBranch) } @@ -77,12 +80,12 @@ func runReparent(newParent string) error { // Update git config configKey := fmt.Sprintf("branch.%s.stackparent", currentBranch) - if err := git.SetConfig(configKey, newParent); err != nil { + if err := gitClient.SetConfig(configKey, newParent); err != nil { return fmt.Errorf("failed to update parent config: %w", err) } // Check if there's a PR for this branch - pr, err := github.GetPRForBranch(currentBranch) + pr, err := githubClient.GetPRForBranch(currentBranch) if err != nil { // Error fetching PR info, but config was updated successfully fmt.Printf("✓ Updated parent to %s\n", newParent) @@ -94,7 +97,7 @@ func runReparent(newParent string) error { // PR exists, update its base fmt.Printf("Updating PR #%d base: %s -> %s\n", pr.Number, pr.Base, newParent) - if err := github.UpdatePRBase(pr.Number, newParent); err != nil { + if err := githubClient.UpdatePRBase(pr.Number, newParent); err != nil { // Config was updated but PR base update failed fmt.Printf("✓ Updated parent to %s\n", newParent) return fmt.Errorf("failed to update PR base: %w", err) @@ -116,7 +119,7 @@ func runReparent(newParent string) error { } // isDescendant checks if possibleDescendant is a descendant of ancestor in the stack -func isDescendant(ancestor, possibleDescendant string) bool { +func isDescendant(gitClient git.GitClient, ancestor, possibleDescendant string) bool { // Walk up from possibleDescendant to see if we reach ancestor current := possibleDescendant visited := make(map[string]bool) @@ -129,7 +132,7 @@ func isDescendant(ancestor, possibleDescendant string) bool { visited[current] = true // Get parent of current - parent := git.GetConfig(fmt.Sprintf("branch.%s.stackparent", current)) + parent := gitClient.GetConfig(fmt.Sprintf("branch.%s.stackparent", current)) if parent == "" { // Reached the top of the stack without finding ancestor return false diff --git a/cmd/root.go b/cmd/root.go index 0c3d067..eec3384 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -50,7 +50,8 @@ The tool helps you create, navigate, and sync stacked branches with minimal over spinner.Enabled = !verbose // Validate we're in a git repository - if _, err := git.GetRepoRoot(); err != nil { + gitClient := git.NewGitClient() + if _, err := gitClient.GetRepoRoot(); err != nil { fmt.Fprintf(os.Stderr, "Error: not in a git repository\n") os.Exit(1) } diff --git a/cmd/status.go b/cmd/status.go index d2b683f..9eaf51c 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -40,7 +40,10 @@ This helps you visualize your stack and see which branches have PRs.`, # | # feature-auth-tests *`, Run: func(cmd *cobra.Command, args []string) { - if err := runStatus(); err != nil { + gitClient := git.NewGitClient() + githubClient := github.NewGitHubClient() + + if err := runStatus(gitClient, githubClient); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } @@ -51,7 +54,7 @@ func init() { statusCmd.Flags().BoolVar(&noPR, "no-pr", false, "Skip fetching PR information (faster)") } -func runStatus() error { +func runStatus(gitClient git.GitClient, githubClient github.GitHubClient) error { var currentBranch string var stackBranches []stack.StackBranch var tree *stack.TreeNode @@ -68,7 +71,7 @@ func runStatus() error { wg.Add(2) go func() { defer wg.Done() - prCache, prErr = github.GetAllPRs() + prCache, prErr = githubClient.GetAllPRs() if prErr != nil { // If fetching fails, fall back to empty cache prCache = make(map[string]*github.PRInfo) @@ -77,7 +80,7 @@ func runStatus() error { go func() { defer wg.Done() // Fetch latest changes from origin (needed for sync issue detection) - _ = git.Fetch() + _ = gitClient.Fetch() fetchDone = true }() } else { @@ -88,13 +91,13 @@ func runStatus() error { if err := spinner.WrapWithAutoDelay("Loading stack...", 300*time.Millisecond, func() error { // Get current branch var err error - currentBranch, err = git.GetCurrentBranch() + currentBranch, err = gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } // Check if there are any stack branches - stackBranches, err = stack.GetStackBranches() + stackBranches, err = stack.GetStackBranches(gitClient) if err != nil { return fmt.Errorf("failed to get stack branches: %w", err) } @@ -104,7 +107,7 @@ func runStatus() error { } // Build stack tree for current branch only - tree, err = stack.BuildStackTreeForBranch(currentBranch) + tree, err = stack.BuildStackTreeForBranch(gitClient, currentBranch) if err != nil { return fmt.Errorf("failed to build stack tree: %w", err) } @@ -135,7 +138,7 @@ func runStatus() error { // Print the tree fmt.Println() - printTree(tree, "", true, currentBranch, prCache) + printTree(gitClient, tree, "", true, currentBranch, prCache) // Check for sync issues (skip if --no-pr) if !noPR { @@ -154,7 +157,7 @@ func runStatus() error { var syncResult *syncIssuesResult if err := spinner.WrapWithAutoDelayAndProgress("Checking for sync issues...", 300*time.Millisecond, func(progress spinner.ProgressFunc) error { var err error - syncResult, err = detectSyncIssues(treeBranches, prCache, progress, fetchDone) + syncResult, err = detectSyncIssues(gitClient, treeBranches, prCache, progress, fetchDone) return err }); err != nil { // Don't fail on detection errors, just skip the check @@ -232,16 +235,16 @@ func filterMergedBranches(node *stack.TreeNode, prCache map[string]*github.PRInf return node } -func printTree(node *stack.TreeNode, prefix string, isLast bool, currentBranch string, prCache map[string]*github.PRInfo) { +func printTree(gitClient git.GitClient, node *stack.TreeNode, prefix string, isLast bool, currentBranch string, prCache map[string]*github.PRInfo) { if node == nil { return } // Flatten the tree into a vertical list - printTreeVertical(node, currentBranch, prCache, false) + printTreeVertical(gitClient, node, currentBranch, prCache, false) } -func printTreeVertical(node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo, isPipe bool) { +func printTreeVertical(gitClient git.GitClient, node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo, isPipe bool) { if node == nil { return } @@ -254,7 +257,7 @@ func printTreeVertical(node *stack.TreeNode, currentBranch string, prCache map[s // Get PR info from cache prInfo := "" - if node.Name != stack.GetBaseBranch() { + if node.Name != stack.GetBaseBranch(gitClient) { if pr, exists := prCache[node.Name]; exists { prInfo = fmt.Sprintf(" [%s :%s]", pr.URL, strings.ToLower(pr.State)) } @@ -270,7 +273,7 @@ func printTreeVertical(node *stack.TreeNode, currentBranch string, prCache map[s // Print children vertically for _, child := range node.Children { - printTreeVertical(child, currentBranch, prCache, true) + printTreeVertical(gitClient, child, currentBranch, prCache, true) } } @@ -282,7 +285,7 @@ type syncIssuesResult struct { // detectSyncIssues checks if any branches are out of sync and returns the issues (doesn't print) // If skipFetch is true, assumes git fetch was already called (to avoid redundant network calls) -func detectSyncIssues(stackBranches []stack.StackBranch, prCache map[string]*github.PRInfo, progress spinner.ProgressFunc, skipFetch bool) (*syncIssuesResult, error) { +func detectSyncIssues(gitClient git.GitClient, stackBranches []stack.StackBranch, prCache map[string]*github.PRInfo, progress spinner.ProgressFunc, skipFetch bool) (*syncIssuesResult, error) { var issues []string var mergedBranches []string @@ -292,7 +295,7 @@ func detectSyncIssues(stackBranches []stack.StackBranch, prCache map[string]*git if verbose { fmt.Println("Fetching latest changes from origin...") } - _ = git.Fetch() + _ = gitClient.Fetch() } if verbose { @@ -317,7 +320,7 @@ func detectSyncIssues(stackBranches []stack.StackBranch, prCache map[string]*git } // Check if parent has a merged PR (child needs to be updated) - if branch.Parent != stack.GetBaseBranch() { + if branch.Parent != stack.GetBaseBranch(gitClient) { if parentPR, exists := prCache[branch.Parent]; exists && parentPR.State == "MERGED" { if verbose { fmt.Printf(" ✗ Parent '%s' has merged PR #%d\n", branch.Parent, parentPR.Number) @@ -350,7 +353,7 @@ func detectSyncIssues(stackBranches []stack.StackBranch, prCache map[string]*git if verbose { fmt.Printf(" Checking if branch is behind parent %s...\n", branch.Parent) } - behind, err := git.IsCommitsBehind(branch.Name, branch.Parent) + behind, err := gitClient.IsCommitsBehind(branch.Name, branch.Parent) if err == nil && behind { if verbose { fmt.Printf(" ✗ Branch is behind %s (needs rebase)\n", branch.Parent) diff --git a/cmd/status_test.go b/cmd/status_test.go new file mode 100644 index 0000000..8108949 --- /dev/null +++ b/cmd/status_test.go @@ -0,0 +1,270 @@ +package cmd + +import ( + "testing" + + "github.com/javoire/stackinator/internal/github" + "github.com/javoire/stackinator/internal/stack" + "github.com/javoire/stackinator/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestRunStatus(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + setupMocks func(*testutil.MockGitClient, *testutil.MockGitHubClient) + expectError bool + }{ + { + name: "display simple stack", + setupMocks: func(mockGit *testutil.MockGitClient, mockGH *testutil.MockGitHubClient) { + // Get current branch + mockGit.On("GetCurrentBranch").Return("feature-a", nil) + // Get stack branches (called multiple times in BuildStackTreeForBranch) + stackParents := map[string]string{ + "feature-a": "main", + } + mockGit.On("GetAllStackParents").Return(stackParents, nil).Times(3) // Called 3 times + // Get base branch + mockGit.On("GetConfig", "stack.baseBranch").Return("") + mockGit.On("GetDefaultBranch").Return("main") + // Note: GetAllPRs is NOT called because noPR is true + }, + expectError: false, + }, + { + name: "no stack branches", + setupMocks: func(mockGit *testutil.MockGitClient, mockGH *testutil.MockGitHubClient) { + // Get current branch + mockGit.On("GetCurrentBranch").Return("main", nil) + // Get stack branches (empty) + mockGit.On("GetAllStackParents").Return(make(map[string]string), nil) + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + tt.setupMocks(mockGit, mockGH) + + // Set noPR to true to skip PR fetching in parallel goroutines + noPR = true + + err := runStatus(mockGit, mockGH) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) + + // Reset noPR + noPR = false + }) + } +} + +func TestFilterMergedBranches(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + tree *stack.TreeNode + prCache map[string]*github.PRInfo + currentBranch string + expectFiltered bool + expectedBranches []string + }{ + { + name: "keep merged branch with children", + tree: &stack.TreeNode{ + Name: "main", + Children: []*stack.TreeNode{ + { + Name: "feature-a", + Children: []*stack.TreeNode{ + {Name: "feature-b", Children: nil}, + }, + }, + }, + }, + prCache: map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "MERGED", "main", "Feature A", "url"), + "feature-b": testutil.NewPRInfo(2, "OPEN", "feature-a", "Feature B", "url"), + }, + currentBranch: "feature-b", + expectedBranches: []string{"main", "feature-a", "feature-b"}, // Keep feature-a because it has children + }, + { + name: "filter merged leaf branch", + tree: &stack.TreeNode{ + Name: "main", + Children: []*stack.TreeNode{ + { + Name: "feature-a", + Children: nil, + }, + }, + }, + prCache: map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "MERGED", "main", "Feature A", "url"), + }, + currentBranch: "main", + expectedBranches: []string{"main"}, // Filter out feature-a because it's a merged leaf + }, + { + name: "keep current branch even if merged", + tree: &stack.TreeNode{ + Name: "main", + Children: []*stack.TreeNode{ + { + Name: "feature-a", + Children: nil, + }, + }, + }, + prCache: map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "MERGED", "main", "Feature A", "url"), + }, + currentBranch: "feature-a", + expectedBranches: []string{"main", "feature-a"}, // Keep feature-a because it's current branch + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filtered := filterMergedBranches(tt.tree, tt.prCache, tt.currentBranch) + + // Collect all branch names from filtered tree + var branches []string + var collectBranches func(*stack.TreeNode) + collectBranches = func(node *stack.TreeNode) { + if node == nil { + return + } + branches = append(branches, node.Name) + for _, child := range node.Children { + collectBranches(child) + } + } + collectBranches(filtered) + + assert.Equal(t, tt.expectedBranches, branches) + }) + } +} + +func TestGetAllBranchNamesFromTree(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tree := &stack.TreeNode{ + Name: "main", + Children: []*stack.TreeNode{ + { + Name: "feature-a", + Children: []*stack.TreeNode{ + {Name: "feature-b", Children: nil}, + }, + }, + {Name: "feature-c", Children: nil}, + }, + } + + branches := getAllBranchNamesFromTree(tree) + + assert.Len(t, branches, 4) + assert.Contains(t, branches, "main") + assert.Contains(t, branches, "feature-a") + assert.Contains(t, branches, "feature-b") + assert.Contains(t, branches, "feature-c") +} + +func TestDetectSyncIssues(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + stackBranches []stack.StackBranch + prCache map[string]*github.PRInfo + setupMocks func(*testutil.MockGitClient) + expectedIssues int + expectedMerged int + }{ + { + name: "branch behind parent", + stackBranches: []stack.StackBranch{ + {Name: "feature-a", Parent: "main"}, + }, + prCache: make(map[string]*github.PRInfo), + setupMocks: func(mockGit *testutil.MockGitClient) { + mockGit.On("IsCommitsBehind", "feature-a", "main").Return(true, nil) + }, + expectedIssues: 1, + expectedMerged: 0, + }, + { + name: "branch with merged PR", + stackBranches: []stack.StackBranch{ + {Name: "feature-a", Parent: "main"}, + }, + prCache: map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "MERGED", "main", "Feature A", "url"), + }, + setupMocks: func(mockGit *testutil.MockGitClient) { + // No calls expected for merged branches + }, + expectedIssues: 0, + expectedMerged: 1, + }, + { + name: "parent PR merged", + stackBranches: []stack.StackBranch{ + {Name: "feature-b", Parent: "feature-a"}, + }, + prCache: map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "MERGED", "main", "Feature A", "url"), + }, + setupMocks: func(mockGit *testutil.MockGitClient) { + mockGit.On("GetDefaultBranch").Return("main") + mockGit.On("IsCommitsBehind", "feature-b", "feature-a").Return(false, nil) + }, + expectedIssues: 1, // Issue because parent is merged + expectedMerged: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + tt.setupMocks(mockGit) + + // Mock GetBaseBranch calls + mockGit.On("GetConfig", "stack.baseBranch").Return("").Maybe() + mockGit.On("GetDefaultBranch").Return("main").Maybe() + + nopProgress := func(msg string) {} // No-op progress function + result, err := detectSyncIssues(mockGit, tt.stackBranches, tt.prCache, nopProgress, true) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Len(t, result.issues, tt.expectedIssues, "Expected %d issues, got %d", tt.expectedIssues, len(result.issues)) + assert.Len(t, result.mergedBranches, tt.expectedMerged, "Expected %d merged branches, got %d", tt.expectedMerged, len(result.mergedBranches)) + + mockGit.AssertExpectations(t) + }) + } +} + diff --git a/cmd/sync.go b/cmd/sync.go index 55e72cd..c005523 100644 --- a/cmd/sync.go +++ b/cmd/sync.go @@ -49,7 +49,10 @@ Uncommitted changes are automatically stashed and reapplied (using --autostash). git checkout main && git pull stack sync`, Run: func(cmd *cobra.Command, args []string) { - if err := runSync(); err != nil { + gitClient := git.NewGitClient() + githubClient := github.NewGitHubClient() + + if err := runSync(gitClient, githubClient); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } @@ -60,15 +63,15 @@ func init() { syncCmd.Flags().BoolVarP(&syncForce, "force", "f", false, "Force push even if local and remote have diverged (use with caution)") } -func runSync() error { +func runSync(gitClient git.GitClient, githubClient github.GitHubClient) error { // Get current branch so we can return to it - originalBranch, err := git.GetCurrentBranch() + originalBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } // Check if working tree is clean and stash if needed - clean, err := git.IsWorkingTreeClean() + clean, err := gitClient.IsWorkingTreeClean() if err != nil { return fmt.Errorf("failed to check working tree status: %w", err) } @@ -76,7 +79,7 @@ func runSync() error { stashed := false if !clean { fmt.Println("Stashing uncommitted changes...") - if err := git.Stash("stack-sync-autostash"); err != nil { + if err := gitClient.Stash("stack-sync-autostash"); err != nil { return fmt.Errorf("failed to stash changes: %w", err) } stashed = true @@ -90,7 +93,7 @@ func runSync() error { defer func() { if stashed && !success { fmt.Println("\nRestoring stashed changes...") - if err := git.StashPop(); err != nil { + if err := gitClient.StashPop(); err != nil { fmt.Fprintf(os.Stderr, "Warning: failed to restore stashed changes: %v\n", err) fmt.Fprintf(os.Stderr, "Run 'git stash pop' manually to restore your changes\n") } @@ -99,8 +102,8 @@ func runSync() error { // Check if current branch is in a stack BEFORE doing any network operations // This allows us to prompt the user immediately if needed - baseBranch := stack.GetBaseBranch() - parent := git.GetConfig(fmt.Sprintf("branch.%s.stackparent", originalBranch)) + baseBranch := stack.GetBaseBranch(gitClient) + parent := gitClient.GetConfig(fmt.Sprintf("branch.%s.stackparent", originalBranch)) if parent == "" && originalBranch != baseBranch { fmt.Printf("Branch '%s' is not in a stack.\n", originalBranch) @@ -120,7 +123,7 @@ func runSync() error { // Set the parent configKey := fmt.Sprintf("branch.%s.stackparent", originalBranch) - if err := git.SetConfig(configKey, baseBranch); err != nil { + if err := gitClient.SetConfig(configKey, baseBranch); err != nil { return fmt.Errorf("failed to set parent: %w", err) } fmt.Printf("✓ Added '%s' to stack with parent '%s'\n", originalBranch, baseBranch) @@ -140,16 +143,16 @@ func runSync() error { wg.Add(2) go func() { defer wg.Done() - fetchErr = git.Fetch() + fetchErr = gitClient.Fetch() }() go func() { defer wg.Done() - prCache, prErr = github.GetAllPRs() + prCache, prErr = githubClient.GetAllPRs() }() // While network operations run in background, do local work // Get only branches in the current branch's stack - chain, err := stack.GetStackChain(originalBranch) + chain, err := stack.GetStackChain(gitClient, originalBranch) if err != nil { return fmt.Errorf("failed to get stack chain: %w", err) } @@ -168,7 +171,7 @@ func runSync() error { } // Get all stack branches and filter to current stack only - allStackBranches, err := stack.GetStackBranches() + allStackBranches, err := stack.GetStackBranches(gitClient) if err != nil { return fmt.Errorf("failed to get stack branches: %w", err) } @@ -187,14 +190,14 @@ func runSync() error { } // Check if any branches in the current stack are in worktrees - worktrees, err := git.GetWorktreeBranches() + worktrees, err := gitClient.GetWorktreeBranches() if err != nil { // Non-fatal, continue without worktree detection worktrees = make(map[string]string) } // Get current worktree path to check if we're already in the right place - currentWorktreePath, err := git.GetCurrentWorktreePath() + currentWorktreePath, err := gitClient.GetCurrentWorktreePath() if err != nil { // Non-fatal, continue without worktree path detection currentWorktreePath = "" @@ -238,7 +241,7 @@ func runSync() error { } // Get all remote branches in one call (more efficient than checking each branch individually) - remoteBranches := git.GetRemoteBranchesSet() + remoteBranches := gitClient.GetRemoteBranchesSet() fmt.Printf("Processing %d branch(es)...\n\n", len(sorted)) @@ -257,7 +260,7 @@ func runSync() error { fmt.Printf("%s Skipping %s (PR #%d is merged)...\n", progress, branch.Name, pr.Number) fmt.Printf(" Removing from stack tracking...\n") configKey := fmt.Sprintf("branch.%s.stackparent", branch.Name) - if err := git.UnsetConfig(configKey); err != nil { + if err := gitClient.UnsetConfig(configKey); err != nil { fmt.Fprintf(os.Stderr, " Warning: failed to remove stack config: %v\n", err) } else { fmt.Printf(" ✓ Removed. You can delete this branch with: git branch -d %s\n", branch.Name) @@ -278,14 +281,14 @@ func runSync() error { oldParent = branch.Parent // Update parent to grandparent - grandparent := git.GetConfig(fmt.Sprintf("branch.%s.stackparent", branch.Parent)) + grandparent := gitClient.GetConfig(fmt.Sprintf("branch.%s.stackparent", branch.Parent)) if grandparent == "" { - grandparent = stack.GetBaseBranch() + grandparent = stack.GetBaseBranch(gitClient) } fmt.Printf(" Updating parent from %s to %s\n", branch.Parent, grandparent) configKey := fmt.Sprintf("branch.%s.stackparent", branch.Name) - if err := git.SetConfig(configKey, grandparent); err != nil { + if err := gitClient.SetConfig(configKey, grandparent); err != nil { fmt.Fprintf(os.Stderr, " Warning: failed to update parent config: %v\n", err) } else { branch.Parent = grandparent @@ -293,7 +296,7 @@ func runSync() error { } // Checkout the branch - if err := git.CheckoutBranch(branch.Name); err != nil { + if err := gitClient.CheckoutBranch(branch.Name); err != nil { return fmt.Errorf("failed to checkout %s: %w", branch.Name, err) } @@ -302,19 +305,19 @@ func runSync() error { branchExistsOnRemote := remoteBranches[branch.Name] if branchExistsOnRemote && !syncForce { // Check if local and remote have diverged - localHash, err := git.GetCommitHash(branch.Name) + localHash, err := gitClient.GetCommitHash(branch.Name) if err != nil { return fmt.Errorf("failed to get local commit hash: %w", err) } - remoteHash, err := git.GetCommitHash(remoteBranch) + remoteHash, err := gitClient.GetCommitHash(remoteBranch) if err != nil { return fmt.Errorf("failed to get remote commit hash: %w", err) } if localHash != remoteHash { // Check merge base to determine relationship - mergeBase, err := git.GetMergeBase(branch.Name, remoteBranch) + mergeBase, err := gitClient.GetMergeBase(branch.Name, remoteBranch) if err != nil { return fmt.Errorf("failed to get merge base: %w", err) } @@ -327,7 +330,7 @@ func runSync() error { } else if mergeBase == localHash { // Local is behind remote (safe to fast-forward) fmt.Printf(" Fast-forwarding to origin/%s...\n", branch.Name) - if err := git.ResetToRemote(branch.Name); err != nil { + if err := gitClient.ResetToRemote(branch.Name); err != nil { return fmt.Errorf("failed to fast-forward: %w", err) } } else { @@ -374,9 +377,9 @@ func runSync() error { // Parent was merged - use --onto to handle squash merge // This excludes commits from oldParent that are now in rebaseTarget fmt.Printf(" Using --onto to handle squash merge (excluding commits from %s)\n", oldParent) - return git.RebaseOnto(rebaseTarget, oldParent, branch.Name) + return gitClient.RebaseOnto(rebaseTarget, oldParent, branch.Name) } - return git.Rebase(rebaseTarget) + return gitClient.Rebase(rebaseTarget) }, ); err != nil { fmt.Fprintf(os.Stderr, " Please resolve conflicts and run 'git rebase --continue', then run 'stack sync' again\n") @@ -394,7 +397,7 @@ func runSync() error { if git.Verbose { fmt.Printf(" Using --force (bypassing safety checks)\n") } - return git.ForcePush(branch.Name) + return gitClient.ForcePush(branch.Name) } // Fetch one more time right before push to ensure --force-with-lease has fresh tracking info @@ -402,7 +405,7 @@ func runSync() error { if git.Verbose { fmt.Printf(" Refreshing remote tracking ref before push...\n") } - if err := git.FetchBranch(branch.Name); err != nil { + if err := gitClient.FetchBranch(branch.Name); err != nil { // Non-fatal, continue with push if git.Verbose { fmt.Fprintf(os.Stderr, " Note: could not refresh tracking ref: %v\n", err) @@ -410,7 +413,7 @@ func runSync() error { } // Use --force-with-lease (safe force push) - return git.Push(branch.Name, true) + return gitClient.Push(branch.Name, true) }, ) @@ -431,7 +434,7 @@ func runSync() error { if pr != nil { if pr.Base != branch.Parent { fmt.Printf(" Updating PR #%d base from %s to %s...\n", pr.Number, pr.Base, branch.Parent) - if err := github.UpdatePRBase(pr.Number, branch.Parent); err != nil { + if err := githubClient.UpdatePRBase(pr.Number, branch.Parent); err != nil { fmt.Fprintf(os.Stderr, " Warning: failed to update PR base: %v\n", err) } else { fmt.Printf(" ✓ PR #%d updated\n", pr.Number) @@ -448,14 +451,14 @@ func runSync() error { // Return to original branch fmt.Printf("Returning to %s...\n", originalBranch) - if err := git.CheckoutBranch(originalBranch); err != nil { + if err := gitClient.CheckoutBranch(originalBranch); err != nil { fmt.Fprintf(os.Stderr, "Warning: failed to return to original branch: %v\n", err) } fmt.Println() // Display the updated stack status (reuse prCache to avoid redundant API call) - if err := displayStatusAfterSync(prCache); err != nil { + if err := displayStatusAfterSync(gitClient, githubClient, prCache); err != nil { // Don't fail if we can't display status, just warn fmt.Fprintf(os.Stderr, "Warning: failed to display stack status: %v\n", err) } @@ -467,7 +470,7 @@ func runSync() error { if stashed { fmt.Println() fmt.Println("Restoring stashed changes...") - if err := git.StashPop(); err != nil { + if err := gitClient.StashPop(); err != nil { fmt.Fprintf(os.Stderr, "Warning: failed to restore stashed changes: %v\n", err) fmt.Fprintf(os.Stderr, "Run 'git stash pop' manually to restore your changes\n") } @@ -481,13 +484,13 @@ func runSync() error { // displayStatusAfterSync shows the stack tree after a successful sync // It reuses the prCache from earlier to avoid a redundant API call -func displayStatusAfterSync(prCache map[string]*github.PRInfo) error { - currentBranch, err := git.GetCurrentBranch() +func displayStatusAfterSync(gitClient git.GitClient, githubClient github.GitHubClient, prCache map[string]*github.PRInfo) error { + currentBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } - tree, err := stack.BuildStackTreeForBranch(currentBranch) + tree, err := stack.BuildStackTreeForBranch(gitClient, currentBranch) if err != nil { return fmt.Errorf("failed to build stack tree: %w", err) } @@ -496,7 +499,7 @@ func displayStatusAfterSync(prCache map[string]*github.PRInfo) error { tree = filterMergedBranchesForSync(tree, prCache) // Print the tree - printTreeForSync(tree, currentBranch, prCache) + printTreeForSync(gitClient, tree, currentBranch, prCache) return nil } @@ -534,14 +537,14 @@ func filterMergedBranchesForSync(node *stack.TreeNode, prCache map[string]*githu } // printTreeForSync prints the stack tree after sync -func printTreeForSync(node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo) { +func printTreeForSync(gitClient git.GitClient, node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo) { if node == nil { return } - printTreeVerticalForSync(node, currentBranch, prCache, false) + printTreeVerticalForSync(gitClient, node, currentBranch, prCache, false) } -func printTreeVerticalForSync(node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo, isPipe bool) { +func printTreeVerticalForSync(gitClient git.GitClient, node *stack.TreeNode, currentBranch string, prCache map[string]*github.PRInfo, isPipe bool) { if node == nil { return } @@ -554,7 +557,7 @@ func printTreeVerticalForSync(node *stack.TreeNode, currentBranch string, prCach // Get PR info from cache prInfo := "" - if node.Name != stack.GetBaseBranch() { + if node.Name != stack.GetBaseBranch(gitClient) { if pr, exists := prCache[node.Name]; exists { prInfo = fmt.Sprintf(" [%s :%s]", pr.URL, strings.ToLower(pr.State)) } @@ -570,6 +573,6 @@ func printTreeVerticalForSync(node *stack.TreeNode, currentBranch string, prCach // Print children vertically for _, child := range node.Children { - printTreeVerticalForSync(child, currentBranch, prCache, true) + printTreeVerticalForSync(gitClient, child, currentBranch, prCache, true) } } diff --git a/cmd/sync_test.go b/cmd/sync_test.go new file mode 100644 index 0000000..ec0e979 --- /dev/null +++ b/cmd/sync_test.go @@ -0,0 +1,369 @@ +package cmd + +import ( + "fmt" + "testing" + + "github.com/javoire/stackinator/internal/github" + "github.com/javoire/stackinator/internal/stack" + "github.com/javoire/stackinator/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestRunSyncBasic(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + t.Run("sync simple 2-branch stack", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Setup: Get current branch + mockGit.On("GetCurrentBranch").Return("feature-b", nil) + // Check working tree + mockGit.On("IsWorkingTreeClean").Return(true, nil) + // Get base branch + mockGit.On("GetConfig", "branch.feature-b.stackparent").Return("feature-a") + mockGit.On("GetConfig", "stack.baseBranch").Return("").Maybe() + mockGit.On("GetDefaultBranch").Return("main").Maybe() // Called many times in tree printing + // Get stack chain + stackParents := map[string]string{ + "feature-a": "main", + "feature-b": "feature-a", + } + mockGit.On("GetAllStackParents").Return(stackParents, nil).Maybe() // Called in GetStackChain, TopologicalSort, and displayStatusAfterSync + // Parallel operations + mockGit.On("Fetch").Return(nil) + mockGH.On("GetAllPRs").Return(make(map[string]*github.PRInfo), nil) + // Check if any branches in the current stack are in worktrees + mockGit.On("GetWorktreeBranches").Return(make(map[string]string), nil) + mockGit.On("GetCurrentWorktreePath").Return("/Users/test/repo", nil) + // Get current worktree path + mockGit.On("GetCurrentWorktreePath").Return("/Users/test/repo", nil) + // Get remote branches + mockGit.On("GetRemoteBranchesSet").Return(map[string]bool{ + "main": true, + "feature-a": true, + "feature-b": true, + }) + // Process feature-a + mockGit.On("CheckoutBranch", "feature-a").Return(nil) + mockGit.On("GetCommitHash", "feature-a").Return("abc123", nil) + mockGit.On("GetCommitHash", "origin/feature-a").Return("abc123", nil) + mockGit.On("Rebase", "origin/main").Return(nil) + mockGit.On("FetchBranch", "feature-a").Return(nil) + mockGit.On("Push", "feature-a", true).Return(nil) + // Process feature-b + mockGit.On("CheckoutBranch", "feature-b").Return(nil) + mockGit.On("GetCommitHash", "feature-b").Return("def456", nil) + mockGit.On("GetCommitHash", "origin/feature-b").Return("def456", nil) + mockGit.On("Rebase", "feature-a").Return(nil) + mockGit.On("FetchBranch", "feature-b").Return(nil) + mockGit.On("Push", "feature-b", true).Return(nil) + // Return to original branch + mockGit.On("CheckoutBranch", "feature-b").Return(nil) + + err := runSync(mockGit, mockGH) + + assert.NoError(t, err) + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) + }) +} + +func TestRunSyncMergedParent(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + t.Run("rebase when parent PR is merged", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Setup + mockGit.On("GetCurrentBranch").Return("feature-b", nil) + mockGit.On("IsWorkingTreeClean").Return(true, nil) + mockGit.On("GetConfig", "branch.feature-b.stackparent").Return("feature-a") + mockGit.On("GetConfig", "stack.baseBranch").Return("").Maybe() + mockGit.On("GetDefaultBranch").Return("main").Maybe() // Called many times in tree printing + + stackParents := map[string]string{ + "feature-a": "main", + "feature-b": "feature-a", + } + mockGit.On("GetAllStackParents").Return(stackParents, nil).Maybe() // Called in GetStackChain, TopologicalSort, and displayStatusAfterSync + + // Parallel operations + mockGit.On("Fetch").Return(nil) + + // Parent PR is merged + prCache := map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "MERGED", "main", "Feature A", "url"), + } + mockGH.On("GetAllPRs").Return(prCache, nil) + + mockGit.On("GetWorktreeBranches").Return(make(map[string]string), nil) + mockGit.On("GetCurrentWorktreePath").Return("/Users/test/repo", nil) + mockGit.On("GetRemoteBranchesSet").Return(map[string]bool{ + "main": true, + "feature-a": true, + "feature-b": true, + }) + + // Process feature-a (merged, skip) + mockGit.On("UnsetConfig", "branch.feature-a.stackparent").Return(nil) + + // Process feature-b (parent is merged, update parent to grandparent) + mockGit.On("GetConfig", "branch.feature-a.stackparent").Return("main") + mockGit.On("SetConfig", "branch.feature-b.stackparent", "main").Return(nil) + mockGit.On("CheckoutBranch", "feature-b").Return(nil) + mockGit.On("GetCommitHash", "feature-b").Return("def456", nil) + mockGit.On("GetCommitHash", "origin/feature-b").Return("def456", nil) + mockGit.On("RebaseOnto", "origin/main", "feature-a", "feature-b").Return(nil) + mockGit.On("FetchBranch", "feature-b").Return(nil) + mockGit.On("Push", "feature-b", true).Return(nil) + + // Return to original branch + mockGit.On("CheckoutBranch", "feature-b").Return(nil) + + err := runSync(mockGit, mockGH) + + assert.NoError(t, err) + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) + }) +} + +func TestRunSyncUpdatePRBase(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + t.Run("update PR base when it doesn't match parent", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Setup + mockGit.On("GetCurrentBranch").Return("feature-b", nil) + mockGit.On("IsWorkingTreeClean").Return(true, nil) + mockGit.On("GetConfig", "branch.feature-b.stackparent").Return("feature-a") + mockGit.On("GetConfig", "stack.baseBranch").Return("").Maybe() + mockGit.On("GetDefaultBranch").Return("main").Maybe() // Called many times in tree printing + + stackParents := map[string]string{ + "feature-a": "main", + "feature-b": "feature-a", + } + mockGit.On("GetAllStackParents").Return(stackParents, nil).Maybe() // Called in GetStackChain, TopologicalSort, and displayStatusAfterSync + + // Parallel operations + mockGit.On("Fetch").Return(nil) + + // PRs with mismatched base + prCache := map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "OPEN", "main", "Feature A", "url"), + "feature-b": testutil.NewPRInfo(2, "OPEN", "main", "Feature B", "url"), // Wrong base! + } + mockGH.On("GetAllPRs").Return(prCache, nil) + + mockGit.On("GetWorktreeBranches").Return(make(map[string]string), nil) + mockGit.On("GetCurrentWorktreePath").Return("/Users/test/repo", nil) + mockGit.On("GetRemoteBranchesSet").Return(map[string]bool{ + "main": true, + "feature-a": true, + "feature-b": true, + }) + + // Process feature-a + mockGit.On("CheckoutBranch", "feature-a").Return(nil) + mockGit.On("GetCommitHash", "feature-a").Return("abc123", nil) + mockGit.On("GetCommitHash", "origin/feature-a").Return("abc123", nil) + mockGit.On("Rebase", "origin/main").Return(nil) + mockGit.On("FetchBranch", "feature-a").Return(nil) + mockGit.On("Push", "feature-a", true).Return(nil) + + // Process feature-b + mockGit.On("CheckoutBranch", "feature-b").Return(nil) + mockGit.On("GetCommitHash", "feature-b").Return("def456", nil) + mockGit.On("GetCommitHash", "origin/feature-b").Return("def456", nil) + mockGit.On("Rebase", "feature-a").Return(nil) + mockGit.On("FetchBranch", "feature-b").Return(nil) + mockGit.On("Push", "feature-b", true).Return(nil) + // Update PR base! + mockGH.On("UpdatePRBase", 2, "feature-a").Return(nil) + + // Return to original branch + mockGit.On("CheckoutBranch", "feature-b").Return(nil) + + err := runSync(mockGit, mockGH) + + assert.NoError(t, err) + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) + }) +} + +func TestRunSyncStashHandling(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + t.Run("stash and restore uncommitted changes", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + // Setup + mockGit.On("GetCurrentBranch").Return("feature-a", nil) + // Working tree is dirty + mockGit.On("IsWorkingTreeClean").Return(false, nil) + // Stash changes + mockGit.On("Stash", "stack-sync-autostash").Return(nil) + + mockGit.On("GetConfig", "branch.feature-a.stackparent").Return("main") + mockGit.On("GetConfig", "stack.baseBranch").Return("").Maybe() + mockGit.On("GetDefaultBranch").Return("main").Maybe() + + stackParents := map[string]string{ + "feature-a": "main", + } + mockGit.On("GetAllStackParents").Return(stackParents, nil).Maybe() // Called in GetStackChain, TopologicalSort, and displayStatusAfterSync + + mockGit.On("Fetch").Return(nil) + mockGH.On("GetAllPRs").Return(make(map[string]*github.PRInfo), nil) + + mockGit.On("GetWorktreeBranches").Return(make(map[string]string), nil) + mockGit.On("GetCurrentWorktreePath").Return("/Users/test/repo", nil) + mockGit.On("GetRemoteBranchesSet").Return(map[string]bool{ + "main": true, + "feature-a": true, + }) + + // Process feature-a + mockGit.On("CheckoutBranch", "feature-a").Return(nil) + mockGit.On("GetCommitHash", "feature-a").Return("abc123", nil) + mockGit.On("GetCommitHash", "origin/feature-a").Return("abc123", nil) + mockGit.On("Rebase", "origin/main").Return(nil) + mockGit.On("FetchBranch", "feature-a").Return(nil) + mockGit.On("Push", "feature-a", true).Return(nil) + + mockGit.On("CheckoutBranch", "feature-a").Return(nil) + + // Restore stash + mockGit.On("StashPop").Return(nil) + + err := runSync(mockGit, mockGH) + + assert.NoError(t, err) + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) + }) +} + +func TestRunSyncErrorHandling(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + t.Run("rebase conflict", func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + mockGit.On("GetCurrentBranch").Return("feature-a", nil) + mockGit.On("IsWorkingTreeClean").Return(true, nil) + mockGit.On("GetConfig", "branch.feature-a.stackparent").Return("main") + mockGit.On("GetConfig", "stack.baseBranch").Return("").Maybe() + mockGit.On("GetDefaultBranch").Return("main").Maybe() + + stackParents := map[string]string{ + "feature-a": "main", + } + mockGit.On("GetAllStackParents").Return(stackParents, nil).Maybe() // Called in GetStackChain, TopologicalSort, and displayStatusAfterSync + + mockGit.On("Fetch").Return(nil) + mockGH.On("GetAllPRs").Return(make(map[string]*github.PRInfo), nil) + + mockGit.On("GetWorktreeBranches").Return(make(map[string]string), nil) + mockGit.On("GetCurrentWorktreePath").Return("/Users/test/repo", nil) + mockGit.On("GetRemoteBranchesSet").Return(map[string]bool{ + "main": true, + "feature-a": true, + }) + + mockGit.On("CheckoutBranch", "feature-a").Return(nil) + mockGit.On("GetCommitHash", "feature-a").Return("abc123", nil) + mockGit.On("GetCommitHash", "origin/feature-a").Return("abc123", nil) + // Rebase fails + mockGit.On("Rebase", "origin/main").Return(fmt.Errorf("rebase conflict")) + + err := runSync(mockGit, mockGH) + + assert.Error(t, err) + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) + }) +} + +func TestFilterMergedBranchesForSync(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tree := &stack.TreeNode{ + Name: "main", + Children: []*stack.TreeNode{ + {Name: "feature-a", Children: nil}, + { + Name: "feature-b", + Children: []*stack.TreeNode{ + {Name: "feature-c", Children: nil}, + }, + }, + }, + } + + prCache := map[string]*github.PRInfo{ + "feature-a": testutil.NewPRInfo(1, "MERGED", "main", "Feature A", "url"), + "feature-b": testutil.NewPRInfo(2, "MERGED", "main", "Feature B", "url"), + "feature-c": testutil.NewPRInfo(3, "OPEN", "feature-b", "Feature C", "url"), + } + + filtered := filterMergedBranchesForSync(tree, prCache) + + // feature-a should be filtered out (merged leaf) + // feature-b should be kept (merged but has children) + // feature-c should be kept (not merged) + + assert.Equal(t, "main", filtered.Name) + assert.Len(t, filtered.Children, 1) + assert.Equal(t, "feature-b", filtered.Children[0].Name) + assert.Len(t, filtered.Children[0].Children, 1) + assert.Equal(t, "feature-c", filtered.Children[0].Children[0].Name) +} + +func TestRunSyncNoStackBranches(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + mockGit := new(testutil.MockGitClient) + mockGH := new(testutil.MockGitHubClient) + + mockGit.On("GetCurrentBranch").Return("main", nil) + mockGit.On("IsWorkingTreeClean").Return(true, nil) + mockGit.On("GetConfig", "branch.main.stackparent").Return("") + mockGit.On("GetConfig", "stack.baseBranch").Return("").Maybe() + mockGit.On("GetDefaultBranch").Return("main").Maybe() + + // Empty stack + mockGit.On("GetAllStackParents").Return(make(map[string]string), nil) + + mockGit.On("Fetch").Return(nil) + mockGH.On("GetAllPRs").Return(make(map[string]*github.PRInfo), nil) + + // Even with no stack, we still check worktrees + mockGit.On("GetWorktreeBranches").Return(make(map[string]string), nil) + mockGit.On("GetCurrentWorktreePath").Return("/Users/test/repo", nil) + mockGit.On("GetRemoteBranchesSet").Return(make(map[string]bool)) + mockGit.On("CheckoutBranch", "main").Return(nil) // Return to original branch + + err := runSync(mockGit, mockGH) + + assert.NoError(t, err) + mockGit.AssertExpectations(t) + mockGH.AssertExpectations(t) +} + diff --git a/cmd/worktree.go b/cmd/worktree.go index effcb13..1d850ed 100644 --- a/cmd/worktree.go +++ b/cmd/worktree.go @@ -51,15 +51,18 @@ Use --prune to clean up worktrees for branches with merged PRs.`, return nil }, Run: func(cmd *cobra.Command, args []string) { + gitClient := git.NewGitClient() + githubClient := github.NewGitHubClient() + var err error if worktreePrune { - err = runWorktreePrune() + err = runWorktreePrune(gitClient, githubClient) } else { var baseBranch string if len(args) > 1 { baseBranch = args[1] } - err = runWorktree(args[0], baseBranch) + err = runWorktree(gitClient, githubClient, args[0], baseBranch) } if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) @@ -72,9 +75,9 @@ func init() { worktreeCmd.Flags().BoolVar(&worktreePrune, "prune", false, "Remove worktrees for branches with merged PRs") } -func runWorktree(branchName, baseBranch string) error { +func runWorktree(gitClient git.GitClient, githubClient github.GitHubClient, branchName, baseBranch string) error { // Get repo root - repoRoot, err := git.GetRepoRoot() + repoRoot, err := gitClient.GetRepoRoot() if err != nil { return fmt.Errorf("failed to get repo root: %w", err) } @@ -94,40 +97,40 @@ func runWorktree(branchName, baseBranch string) error { // If base branch is specified, always create new branch from it if baseBranch != "" { - return createNewBranchWorktree(branchName, baseBranch, worktreePath) + return createNewBranchWorktree(gitClient, branchName, baseBranch, worktreePath) } // Check if branch exists locally or on remote - return createWorktreeForExisting(branchName, worktreePath) + return createWorktreeForExisting(gitClient, branchName, worktreePath) } -func createNewBranchWorktree(branchName, baseBranch, worktreePath string) error { +func createNewBranchWorktree(gitClient git.GitClient, branchName, baseBranch, worktreePath string) error { // Check if branch already exists - if git.BranchExists(branchName) { + if gitClient.BranchExists(branchName) { return fmt.Errorf("branch %s already exists", branchName) } // Verify base branch exists (locally or on remote) - if !git.BranchExists(baseBranch) && !git.RemoteBranchExists(baseBranch) { + if !gitClient.BranchExists(baseBranch) && !gitClient.RemoteBranchExists(baseBranch) { return fmt.Errorf("base branch %s does not exist locally or on remote", baseBranch) } // Use origin/baseBranch if it's a remote branch to get fresh copy baseRef := baseBranch - if git.RemoteBranchExists(baseBranch) { + if gitClient.RemoteBranchExists(baseBranch) { baseRef = "origin/" + baseBranch } fmt.Printf("Creating new branch %s from %s\n", branchName, baseRef) // Create worktree with new branch - if err := git.AddWorktreeNewBranch(worktreePath, branchName, baseRef); err != nil { + if err := gitClient.AddWorktreeNewBranch(worktreePath, branchName, baseRef); err != nil { return fmt.Errorf("failed to create worktree: %w", err) } // Set parent in git config for stack tracking configKey := fmt.Sprintf("branch.%s.stackparent", branchName) - if err := git.SetConfig(configKey, baseBranch); err != nil { + if err := gitClient.SetConfig(configKey, baseBranch); err != nil { return fmt.Errorf("failed to set parent config: %w", err) } @@ -140,11 +143,11 @@ func createNewBranchWorktree(branchName, baseBranch, worktreePath string) error return nil } -func createWorktreeForExisting(branchName, worktreePath string) error { +func createWorktreeForExisting(gitClient git.GitClient, branchName, worktreePath string) error { // Check if branch exists locally - if git.BranchExists(branchName) { + if gitClient.BranchExists(branchName) { fmt.Printf("Creating worktree for local branch %s\n", branchName) - if err := git.AddWorktree(worktreePath, branchName); err != nil { + if err := gitClient.AddWorktree(worktreePath, branchName); err != nil { return fmt.Errorf("failed to create worktree: %w", err) } if !dryRun { @@ -155,9 +158,9 @@ func createWorktreeForExisting(branchName, worktreePath string) error { } // Check if branch exists on remote - if git.RemoteBranchExists(branchName) { + if gitClient.RemoteBranchExists(branchName) { fmt.Printf("Creating worktree for remote branch %s\n", branchName) - if err := git.AddWorktreeFromRemote(worktreePath, branchName); err != nil { + if err := gitClient.AddWorktreeFromRemote(worktreePath, branchName); err != nil { return fmt.Errorf("failed to create worktree: %w", err) } if !dryRun { @@ -168,19 +171,19 @@ func createWorktreeForExisting(branchName, worktreePath string) error { } // Branch doesn't exist - create new branch from current branch with stack tracking - currentBranch, err := git.GetCurrentBranch() + currentBranch, err := gitClient.GetCurrentBranch() if err != nil { return fmt.Errorf("failed to get current branch: %w", err) } fmt.Printf("Creating new branch %s from %s\n", branchName, currentBranch) - if err := git.AddWorktreeNewBranch(worktreePath, branchName, currentBranch); err != nil { + if err := gitClient.AddWorktreeNewBranch(worktreePath, branchName, currentBranch); err != nil { return fmt.Errorf("failed to create worktree: %w", err) } // Set parent in git config for stack tracking configKey := fmt.Sprintf("branch.%s.stackparent", branchName) - if err := git.SetConfig(configKey, currentBranch); err != nil { + if err := gitClient.SetConfig(configKey, currentBranch); err != nil { return fmt.Errorf("failed to set parent config: %w", err) } @@ -192,9 +195,9 @@ func createWorktreeForExisting(branchName, worktreePath string) error { return nil } -func runWorktreePrune() error { +func runWorktreePrune(gitClient git.GitClient, githubClient github.GitHubClient) error { // Get repo root - repoRoot, err := git.GetRepoRoot() + repoRoot, err := gitClient.GetRepoRoot() if err != nil { return fmt.Errorf("failed to get repo root: %w", err) } @@ -208,7 +211,7 @@ func runWorktreePrune() error { } // Get all worktrees and their branches - worktreeBranches, err := git.GetWorktreeBranches() + worktreeBranches, err := gitClient.GetWorktreeBranches() if err != nil { return fmt.Errorf("failed to list worktrees: %w", err) } @@ -236,7 +239,7 @@ func runWorktreePrune() error { var prCache map[string]*github.PRInfo if err := spinner.WrapWithSuccess("Fetching PRs...", "Fetched PRs", func() error { var prErr error - prCache, prErr = github.GetAllPRs() + prCache, prErr = githubClient.GetAllPRs() return prErr }); err != nil { return fmt.Errorf("failed to fetch PRs: %w", err) @@ -276,7 +279,7 @@ func runWorktreePrune() error { for i, wt := range mergedWorktrees { fmt.Printf("(%d/%d) Removing worktree for %s...\n", i+1, len(mergedWorktrees), wt.branch) - if err := git.RemoveWorktree(wt.path); err != nil { + if err := gitClient.RemoveWorktree(wt.path); err != nil { fmt.Fprintf(os.Stderr, " Warning: failed to remove worktree: %v\n", err) } else { fmt.Println(" ✓ Removed") diff --git a/go.mod b/go.mod index b8dedb9..422d0d7 100644 --- a/go.mod +++ b/go.mod @@ -2,9 +2,16 @@ module github.com/javoire/stackinator go 1.21 -require github.com/spf13/cobra v1.8.0 +require ( + github.com/spf13/cobra v1.8.0 + github.com/stretchr/testify v1.11.1 +) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d0e8c2c..6f5b07d 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,20 @@ github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/git/git.go b/internal/git/git.go index 4cc5de8..3d61f48 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -13,8 +13,16 @@ var Verbose = false // DryRun controls whether to actually execute mutation commands var DryRun = false +// gitClient implements the GitClient interface using exec.Command +type gitClient struct{} + +// NewGitClient creates a new GitClient implementation +func NewGitClient() GitClient { + return &gitClient{} +} + // runCmd executes a git command and returns stdout -func runCmd(args ...string) (string, error) { +func (c *gitClient) runCmd(args ...string) (string, error) { if Verbose { fmt.Printf(" [git] %s\n", strings.Join(args, " ")) } @@ -32,7 +40,7 @@ func runCmd(args ...string) (string, error) { } // runCmdMayFail runs a command that might fail (returns empty string on error) -func runCmdMayFail(args ...string) string { +func (c *gitClient) runCmdMayFail(args ...string) string { if Verbose { fmt.Printf(" [git] %s\n", strings.Join(args, " ")) } @@ -46,18 +54,18 @@ func runCmdMayFail(args ...string) string { } // GetRepoRoot returns the root directory of the git repository -func GetRepoRoot() (string, error) { - return runCmd("rev-parse", "--show-toplevel") +func (c *gitClient) GetRepoRoot() (string, error) { + return c.runCmd("rev-parse", "--show-toplevel") } // GetCurrentBranch returns the name of the currently checked out branch -func GetCurrentBranch() (string, error) { - return runCmd("branch", "--show-current") +func (c *gitClient) GetCurrentBranch() (string, error) { + return c.runCmd("branch", "--show-current") } // ListBranches returns a list of all local branches -func ListBranches() ([]string, error) { - output, err := runCmd("branch", "--format=%(refname:short)") +func (c *gitClient) ListBranches() ([]string, error) { + output, err := c.runCmd("branch", "--format=%(refname:short)") if err != nil { return nil, err } @@ -71,13 +79,13 @@ func ListBranches() ([]string, error) { } // GetConfig reads a git config value -func GetConfig(key string) string { - return runCmdMayFail("config", "--get", key) +func (c *gitClient) GetConfig(key string) string { + return c.runCmdMayFail("config", "--get", key) } // GetAllStackParents fetches all stack parent configs in one call (more efficient) -func GetAllStackParents() (map[string]string, error) { - output, err := runCmd("config", "--get-regexp", "^branch\\..*\\.stackparent$") +func (c *gitClient) GetAllStackParents() (map[string]string, error) { + output, err := c.runCmd("config", "--get-regexp", "^branch\\..*\\.stackparent$") if err != nil { // No stack parents configured return make(map[string]string), nil @@ -107,89 +115,89 @@ func GetAllStackParents() (map[string]string, error) { } // SetConfig writes a git config value -func SetConfig(key, value string) error { +func (c *gitClient) SetConfig(key, value string) error { if DryRun { fmt.Printf(" [DRY RUN] git config %s %s\n", key, value) return nil } - _, err := runCmd("config", key, value) + _, err := c.runCmd("config", key, value) return err } // UnsetConfig removes a git config value -func UnsetConfig(key string) error { +func (c *gitClient) UnsetConfig(key string) error { if DryRun { fmt.Printf(" [DRY RUN] git config --unset %s\n", key) return nil } - _, err := runCmd("config", "--unset", key) + _, err := c.runCmd("config", "--unset", key) return err } // CreateBranch creates a new branch from the specified base and checks it out -func CreateBranch(name, from string) error { +func (c *gitClient) CreateBranch(name, from string) error { if DryRun { fmt.Printf(" [DRY RUN] git checkout -b %s %s\n", name, from) return nil } - _, err := runCmd("checkout", "-b", name, from) + _, err := c.runCmd("checkout", "-b", name, from) return err } // CheckoutBranch switches to the specified branch -func CheckoutBranch(name string) error { +func (c *gitClient) CheckoutBranch(name string) error { if DryRun { fmt.Printf(" [DRY RUN] git checkout %s\n", name) return nil } - _, err := runCmd("checkout", name) + _, err := c.runCmd("checkout", name) return err } // RenameBranch renames a branch (must be on that branch) -func RenameBranch(oldName, newName string) error { +func (c *gitClient) RenameBranch(oldName, newName string) error { if DryRun { fmt.Printf(" [DRY RUN] git branch -m %s %s\n", oldName, newName) return nil } - _, err := runCmd("branch", "-m", oldName, newName) + _, err := c.runCmd("branch", "-m", oldName, newName) return err } // Rebase rebases the current branch onto the specified base -func Rebase(onto string) error { +func (c *gitClient) Rebase(onto string) error { if DryRun { fmt.Printf(" [DRY RUN] git rebase --autostash %s\n", onto) return nil } - _, err := runCmd("rebase", "--autostash", onto) + _, err := c.runCmd("rebase", "--autostash", onto) return err } // RebaseOnto rebases the current branch onto newBase, excluding commits up to and including oldBase // This is useful for handling squash merges where oldBase was squashed into newBase // Equivalent to: git rebase --onto newBase oldBase currentBranch -func RebaseOnto(newBase, oldBase, currentBranch string) error { +func (c *gitClient) RebaseOnto(newBase, oldBase, currentBranch string) error { if DryRun { fmt.Printf(" [DRY RUN] git rebase --autostash --onto %s %s %s\n", newBase, oldBase, currentBranch) return nil } - _, err := runCmd("rebase", "--autostash", "--onto", newBase, oldBase, currentBranch) + _, err := c.runCmd("rebase", "--autostash", "--onto", newBase, oldBase, currentBranch) return err } // FetchBranch fetches a specific branch from origin to update tracking info -func FetchBranch(branch string) error { +func (c *gitClient) FetchBranch(branch string) error { if DryRun { fmt.Printf(" [DRY RUN] git fetch origin %s\n", branch) return nil } - _, err := runCmd("fetch", "origin", branch) + _, err := c.runCmd("fetch", "origin", branch) return err } // Push pushes a branch to origin -func Push(branch string, forceWithLease bool) error { +func (c *gitClient) Push(branch string, forceWithLease bool) error { args := []string{"push"} if forceWithLease { args = append(args, "--force-with-lease") @@ -201,12 +209,12 @@ func Push(branch string, forceWithLease bool) error { return nil } - _, err := runCmd(args...) + _, err := c.runCmd(args...) return err } // ForcePush force pushes a branch to origin (bypasses --force-with-lease safety) -func ForcePush(branch string) error { +func (c *gitClient) ForcePush(branch string) error { args := []string{"push", "--force", "origin", branch} if DryRun { @@ -214,13 +222,13 @@ func ForcePush(branch string) error { return nil } - _, err := runCmd(args...) + _, err := c.runCmd(args...) return err } // IsWorkingTreeClean returns true if there are no uncommitted changes -func IsWorkingTreeClean() (bool, error) { - output, err := runCmd("status", "--porcelain") +func (c *gitClient) IsWorkingTreeClean() (bool, error) { + output, err := c.runCmd("status", "--porcelain") if err != nil { return false, err } @@ -228,32 +236,32 @@ func IsWorkingTreeClean() (bool, error) { } // Fetch fetches from origin -func Fetch() error { +func (c *gitClient) Fetch() error { if DryRun { fmt.Printf(" [DRY RUN] git fetch origin\n") return nil } - _, err := runCmd("fetch", "origin") + _, err := c.runCmd("fetch", "origin") return err } // BranchExists checks if a branch exists locally -func BranchExists(name string) bool { - output := runCmdMayFail("rev-parse", "--verify", "refs/heads/"+name) +func (c *gitClient) BranchExists(name string) bool { + output := c.runCmdMayFail("rev-parse", "--verify", "refs/heads/"+name) return output != "" } // RemoteBranchExists checks if a branch exists on origin -func RemoteBranchExists(name string) bool { - output := runCmdMayFail("rev-parse", "--verify", "refs/remotes/origin/"+name) +func (c *gitClient) RemoteBranchExists(name string) bool { + output := c.runCmdMayFail("rev-parse", "--verify", "refs/remotes/origin/"+name) return output != "" } // GetRemoteBranchesSet fetches all remote branches from origin in one call // and returns a set (map[string]bool) for efficient lookups. // This is more efficient than calling RemoteBranchExists multiple times. -func GetRemoteBranchesSet() map[string]bool { - output := runCmdMayFail("for-each-ref", "--format=%(refname:short)", "refs/remotes/origin/") +func (c *gitClient) GetRemoteBranchesSet() map[string]bool { + output := c.runCmdMayFail("for-each-ref", "--format=%(refname:short)", "refs/remotes/origin/") if output == "" { return make(map[string]bool) } @@ -275,61 +283,61 @@ func GetRemoteBranchesSet() map[string]bool { } // AbortRebase aborts an in-progress rebase -func AbortRebase() error { +func (c *gitClient) AbortRebase() error { if DryRun { fmt.Printf(" [DRY RUN] git rebase --abort\n") return nil } - _, err := runCmd("rebase", "--abort") + _, err := c.runCmd("rebase", "--abort") return err } // ResetToRemote resets the current branch to match the remote branch exactly -func ResetToRemote(branch string) error { +func (c *gitClient) ResetToRemote(branch string) error { remoteBranch := "origin/" + branch if DryRun { fmt.Printf(" [DRY RUN] git reset --hard %s\n", remoteBranch) return nil } - _, err := runCmd("reset", "--hard", remoteBranch) + _, err := c.runCmd("reset", "--hard", remoteBranch) return err } // GetMergeBase returns the common ancestor of two branches -func GetMergeBase(branch1, branch2 string) (string, error) { - return runCmd("merge-base", branch1, branch2) +func (c *gitClient) GetMergeBase(branch1, branch2 string) (string, error) { + return c.runCmd("merge-base", branch1, branch2) } // GetCommitHash returns the commit hash of a ref -func GetCommitHash(ref string) (string, error) { - return runCmd("rev-parse", ref) +func (c *gitClient) GetCommitHash(ref string) (string, error) { + return c.runCmd("rev-parse", ref) } // Stash stashes the current changes -func Stash(message string) error { +func (c *gitClient) Stash(message string) error { if DryRun { fmt.Printf(" [DRY RUN] git stash push -m \"%s\"\n", message) return nil } - _, err := runCmd("stash", "push", "-m", message) + _, err := c.runCmd("stash", "push", "-m", message) return err } // StashPop pops the most recent stash -func StashPop() error { +func (c *gitClient) StashPop() error { if DryRun { fmt.Printf(" [DRY RUN] git stash pop\n") return nil } - _, err := runCmd("stash", "pop") + _, err := c.runCmd("stash", "pop") return err } // GetDefaultBranch attempts to detect the repository's default branch // by checking the remote HEAD or falling back to common defaults -func GetDefaultBranch() string { +func (c *gitClient) GetDefaultBranch() string { // Try to get the remote's default branch - output := runCmdMayFail("symbolic-ref", "refs/remotes/origin/HEAD") + output := c.runCmdMayFail("symbolic-ref", "refs/remotes/origin/HEAD") if output != "" { // Output format: refs/remotes/origin/master parts := strings.Split(output, "/") @@ -340,7 +348,7 @@ func GetDefaultBranch() string { // Fall back to checking which common branch exists for _, branch := range []string{"master", "main"} { - if BranchExists(branch) { + if c.BranchExists(branch) { return branch } } @@ -350,8 +358,8 @@ func GetDefaultBranch() string { } // GetWorktreeBranches returns a map of branch names to their worktree paths (resolved to canonical paths) -func GetWorktreeBranches() (map[string]string, error) { - output := runCmdMayFail("worktree", "list", "--porcelain") +func (c *gitClient) GetWorktreeBranches() (map[string]string, error) { + output := c.runCmdMayFail("worktree", "list", "--porcelain") if output == "" { return make(map[string]string), nil } @@ -380,9 +388,9 @@ func GetWorktreeBranches() (map[string]string, error) { } // GetCurrentWorktreePath returns the absolute path of the current worktree -func GetCurrentWorktreePath() (string, error) { +func (c *gitClient) GetCurrentWorktreePath() (string, error) { // Use git rev-parse to get the absolute path to the top-level of the current worktree - path, err := runCmd("rev-parse", "--path-format=absolute", "--show-toplevel") + path, err := c.runCmd("rev-parse", "--path-format=absolute", "--show-toplevel") if err != nil { return "", err } @@ -419,7 +427,7 @@ func resolveSymlinks(path string) (string, error) { } // IsCommitsBehind checks if the 'branch' is behind 'base' (i.e., base has commits that branch doesn't) -func IsCommitsBehind(branch, base string) (bool, error) { +func (c *gitClient) IsCommitsBehind(branch, base string) (bool, error) { // NOTE: Caller should fetch first to ensure latest remote refs // We don't fetch here to avoid multiple fetches in loops @@ -429,7 +437,7 @@ func IsCommitsBehind(branch, base string) (bool, error) { // Get commit count: ahead...behind // Format: "aheadbehind" - output, err := runCmd("rev-list", "--left-right", "--count", branch+"..."+baseBranch) + output, err := c.runCmd("rev-list", "--left-right", "--count", branch+"..."+baseBranch) if err != nil { return false, err } @@ -446,71 +454,71 @@ func IsCommitsBehind(branch, base string) (bool, error) { // DeleteBranch deletes a branch safely (equivalent to git branch -d) // This will fail if the branch has unmerged commits -func DeleteBranch(name string) error { +func (c *gitClient) DeleteBranch(name string) error { if DryRun { fmt.Printf(" [DRY RUN] git branch -d %s\n", name) return nil } - _, err := runCmd("branch", "-d", name) + _, err := c.runCmd("branch", "-d", name) return err } // DeleteBranchForce force deletes a branch (equivalent to git branch -D) // This will delete the branch even if it has unmerged commits -func DeleteBranchForce(name string) error { +func (c *gitClient) DeleteBranchForce(name string) error { if DryRun { fmt.Printf(" [DRY RUN] git branch -D %s\n", name) return nil } - _, err := runCmd("branch", "-D", name) + _, err := c.runCmd("branch", "-D", name) return err } // AddWorktree creates a worktree at the specified path for an existing local branch -func AddWorktree(path, branch string) error { +func (c *gitClient) AddWorktree(path, branch string) error { if DryRun { fmt.Printf(" [DRY RUN] git worktree add %s %s\n", path, branch) return nil } - _, err := runCmd("worktree", "add", path, branch) + _, err := c.runCmd("worktree", "add", path, branch) return err } // AddWorktreeNewBranch creates a worktree with a new branch at the specified path // The new branch is created from the given base branch -func AddWorktreeNewBranch(path, newBranch, baseBranch string) error { +func (c *gitClient) AddWorktreeNewBranch(path, newBranch, baseBranch string) error { if DryRun { fmt.Printf(" [DRY RUN] git worktree add -b %s %s %s\n", newBranch, path, baseBranch) return nil } - _, err := runCmd("worktree", "add", "-b", newBranch, path, baseBranch) + _, err := c.runCmd("worktree", "add", "-b", newBranch, path, baseBranch) return err } // AddWorktreeFromRemote creates a worktree tracking a remote branch // This creates a local branch that tracks the remote branch -func AddWorktreeFromRemote(path, branch string) error { +func (c *gitClient) AddWorktreeFromRemote(path, branch string) error { if DryRun { fmt.Printf(" [DRY RUN] git worktree add --track -b %s %s origin/%s\n", branch, path, branch) return nil } - _, err := runCmd("worktree", "add", "--track", "-b", branch, path, "origin/"+branch) + _, err := c.runCmd("worktree", "add", "--track", "-b", branch, path, "origin/"+branch) return err } // RemoveWorktree removes a worktree at the specified path -func RemoveWorktree(path string) error { +func (c *gitClient) RemoveWorktree(path string) error { if DryRun { fmt.Printf(" [DRY RUN] git worktree remove %s\n", path) return nil } - _, err := runCmd("worktree", "remove", path) + _, err := c.runCmd("worktree", "remove", path) return err } // ListWorktrees returns a list of all worktree paths -func ListWorktrees() ([]string, error) { - output := runCmdMayFail("worktree", "list", "--porcelain") +func (c *gitClient) ListWorktrees() ([]string, error) { + output := c.runCmdMayFail("worktree", "list", "--porcelain") if output == "" { return []string{}, nil } diff --git a/internal/git/git_test.go b/internal/git/git_test.go new file mode 100644 index 0000000..7f6cfe9 --- /dev/null +++ b/internal/git/git_test.go @@ -0,0 +1,22 @@ +package git + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewGitClient(t *testing.T) { + client := NewGitClient() + assert.NotNil(t, client) +} + +func TestGitClientInterface(t *testing.T) { + // Verify that gitClient implements GitClient interface + var _ GitClient = &gitClient{} +} + +// Note: More comprehensive tests would require mocking exec.Command or running actual git commands +// For unit tests focused on critical path, we rely on integration tests or testutil mocks +// The real value is in testing the stack package and command packages with mocked clients + diff --git a/internal/git/interface.go b/internal/git/interface.go new file mode 100644 index 0000000..42df7c2 --- /dev/null +++ b/internal/git/interface.go @@ -0,0 +1,43 @@ +package git + +// GitClient defines the interface for all git operations +type GitClient interface { + GetRepoRoot() (string, error) + GetCurrentBranch() (string, error) + ListBranches() ([]string, error) + GetConfig(key string) string + GetAllStackParents() (map[string]string, error) + SetConfig(key, value string) error + UnsetConfig(key string) error + CreateBranch(name, from string) error + CheckoutBranch(name string) error + RenameBranch(oldName, newName string) error + Rebase(onto string) error + RebaseOnto(newBase, oldBase, currentBranch string) error + FetchBranch(branch string) error + Push(branch string, forceWithLease bool) error + ForcePush(branch string) error + IsWorkingTreeClean() (bool, error) + Fetch() error + BranchExists(name string) bool + RemoteBranchExists(name string) bool + GetRemoteBranchesSet() map[string]bool + AbortRebase() error + ResetToRemote(branch string) error + GetMergeBase(branch1, branch2 string) (string, error) + GetCommitHash(ref string) (string, error) + Stash(message string) error + StashPop() error + GetDefaultBranch() string + GetWorktreeBranches() (map[string]string, error) + GetCurrentWorktreePath() (string, error) + IsCommitsBehind(branch, base string) (bool, error) + DeleteBranch(name string) error + DeleteBranchForce(name string) error + AddWorktree(path, branch string) error + AddWorktreeNewBranch(path, newBranch, baseBranch string) error + AddWorktreeFromRemote(path, branch string) error + RemoveWorktree(path string) error + ListWorktrees() ([]string, error) +} + diff --git a/internal/github/github.go b/internal/github/github.go index cc850d3..60a5cfd 100644 --- a/internal/github/github.go +++ b/internal/github/github.go @@ -25,8 +25,16 @@ type PRInfo struct { MergeStateStatus string // "BEHIND", "BLOCKED", "CLEAN", "DIRTY", "UNKNOWN", "UNSTABLE" } +// githubClient implements the GitHubClient interface using exec.Command +type githubClient struct{} + +// NewGitHubClient creates a new GitHubClient implementation +func NewGitHubClient() GitHubClient { + return &githubClient{} +} + // runGH executes a gh CLI command and returns stdout -func runGH(args ...string) (string, error) { +func (c *githubClient) runGH(args ...string) (string, error) { if Verbose { fmt.Printf(" [gh] %s\n", strings.Join(args, " ")) } @@ -44,8 +52,8 @@ func runGH(args ...string) (string, error) { } // GetPRForBranch returns PR info for the specified branch -func GetPRForBranch(branch string) (*PRInfo, error) { - output, err := runGH("pr", "view", branch, "--json", "number,state,baseRefName,title,url,mergeStateStatus") +func (c *githubClient) GetPRForBranch(branch string) (*PRInfo, error) { + output, err := c.runGH("pr", "view", branch, "--json", "number,state,baseRefName,title,url,mergeStateStatus") if err != nil { // No PR exists for this branch return nil, nil @@ -75,9 +83,9 @@ func GetPRForBranch(branch string) (*PRInfo, error) { } // GetAllPRs fetches all PRs for the repository in a single call -func GetAllPRs() (map[string]*PRInfo, error) { +func (c *githubClient) GetAllPRs() (map[string]*PRInfo, error) { // Fetch all PRs (open, closed, and merged) in one call - output, err := runGH("pr", "list", "--state", "all", "--json", "number,state,headRefName,baseRefName,title,url,mergeStateStatus", "--limit", "1000") + output, err := c.runGH("pr", "list", "--state", "all", "--json", "number,state,headRefName,baseRefName,title,url,mergeStateStatus", "--limit", "1000") if err != nil { return nil, fmt.Errorf("failed to list PRs: %w", err) } @@ -113,19 +121,19 @@ func GetAllPRs() (map[string]*PRInfo, error) { } // UpdatePRBase updates the base branch of a PR -func UpdatePRBase(prNumber int, newBase string) error { +func (c *githubClient) UpdatePRBase(prNumber int, newBase string) error { if DryRun { fmt.Printf(" [DRY RUN] gh pr edit %d --base %s\n", prNumber, newBase) return nil } - _, err := runGH("pr", "edit", strconv.Itoa(prNumber), "--base", newBase) + _, err := c.runGH("pr", "edit", strconv.Itoa(prNumber), "--base", newBase) return err } // IsPRMerged checks if a PR has been merged -func IsPRMerged(prNumber int) (bool, error) { - output, err := runGH("pr", "view", strconv.Itoa(prNumber), "--json", "state") +func (c *githubClient) IsPRMerged(prNumber int) (bool, error) { + output, err := c.runGH("pr", "view", strconv.Itoa(prNumber), "--json", "state") if err != nil { return false, err } @@ -140,5 +148,3 @@ func IsPRMerged(prNumber int) (bool, error) { return data.State == "MERGED", nil } - - diff --git a/internal/github/github_test.go b/internal/github/github_test.go new file mode 100644 index 0000000..bc1fd9d --- /dev/null +++ b/internal/github/github_test.go @@ -0,0 +1,35 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPRInfoParsing(t *testing.T) { + // Test that PRInfo struct is properly defined + pr := &PRInfo{ + Number: 123, + State: "OPEN", + Base: "main", + Title: "Test PR", + URL: "https://github.com/test/repo/pull/123", + MergeStateStatus: "CLEAN", + } + + assert.Equal(t, 123, pr.Number) + assert.Equal(t, "OPEN", pr.State) + assert.Equal(t, "main", pr.Base) + assert.Equal(t, "Test PR", pr.Title) + assert.Equal(t, "https://github.com/test/repo/pull/123", pr.URL) + assert.Equal(t, "CLEAN", pr.MergeStateStatus) +} + +func TestNewGitHubClient(t *testing.T) { + client := NewGitHubClient() + assert.NotNil(t, client) +} + +// Note: More comprehensive tests would require mocking exec.Command or running actual gh CLI commands +// For unit tests focused on critical path, we rely on integration tests or testutil mocks + diff --git a/internal/github/interface.go b/internal/github/interface.go new file mode 100644 index 0000000..5bb6c7f --- /dev/null +++ b/internal/github/interface.go @@ -0,0 +1,10 @@ +package github + +// GitHubClient defines the interface for all GitHub operations +type GitHubClient interface { + GetPRForBranch(branch string) (*PRInfo, error) + GetAllPRs() (map[string]*PRInfo, error) + UpdatePRBase(prNumber int, newBase string) error + IsPRMerged(prNumber int) (bool, error) +} + diff --git a/internal/stack/stack.go b/internal/stack/stack.go index 48dad0c..017f522 100644 --- a/internal/stack/stack.go +++ b/internal/stack/stack.go @@ -15,9 +15,9 @@ type StackBranch struct { } // GetStackBranches returns all branches that are part of a stack -func GetStackBranches() ([]StackBranch, error) { +func GetStackBranches(gitClient git.GitClient) ([]StackBranch, error) { // Fetch all stack parents in one efficient call - parents, err := git.GetAllStackParents() + parents, err := gitClient.GetAllStackParents() if err != nil { return nil, fmt.Errorf("failed to get stack parents: %w", err) } @@ -35,8 +35,8 @@ func GetStackBranches() ([]StackBranch, error) { } // GetChildrenOf returns all direct children of the specified branch -func GetChildrenOf(branch string) ([]StackBranch, error) { - allBranches, err := GetStackBranches() +func GetChildrenOf(gitClient git.GitClient, branch string) ([]StackBranch, error) { + allBranches, err := GetStackBranches(gitClient) if err != nil { return nil, err } @@ -57,9 +57,9 @@ func GetChildrenOf(branch string) ([]StackBranch, error) { } // GetStackChain returns the chain from the base to the specified branch -func GetStackChain(branch string) ([]string, error) { +func GetStackChain(gitClient git.GitClient, branch string) ([]string, error) { // Get all parents at once for efficiency - parents, err := git.GetAllStackParents() + parents, err := gitClient.GetAllStackParents() if err != nil { return nil, err } @@ -149,17 +149,17 @@ func TopologicalSort(branches []StackBranch) ([]StackBranch, error) { } // GetBaseBranch returns the configured base branch or auto-detects it -func GetBaseBranch() string { - base := git.GetConfig("stack.baseBranch") +func GetBaseBranch(gitClient git.GitClient) string { + base := gitClient.GetConfig("stack.baseBranch") if base == "" { - return git.GetDefaultBranch() + return gitClient.GetDefaultBranch() } return base } // BuildStackTree builds a tree representation for display -func BuildStackTree() (*TreeNode, error) { - stackBranches, err := GetStackBranches() +func BuildStackTree(gitClient git.GitClient) (*TreeNode, error) { + stackBranches, err := GetStackBranches(gitClient) if err != nil { return nil, err } @@ -178,14 +178,14 @@ func BuildStackTree() (*TreeNode, error) { } // Build tree starting from base branch - baseBranch := GetBaseBranch() + baseBranch := GetBaseBranch(gitClient) return buildTreeNode(baseBranch, childrenMap), nil } // BuildStackTreeForBranch builds a tree for only the stack containing the specified branch -func BuildStackTreeForBranch(branchName string) (*TreeNode, error) { +func BuildStackTreeForBranch(gitClient git.GitClient, branchName string) (*TreeNode, error) { // Get the chain from base to current branch - chain, err := GetStackChain(branchName) + chain, err := GetStackChain(gitClient, branchName) if err != nil { return nil, err } @@ -201,7 +201,7 @@ func BuildStackTreeForBranch(branchName string) (*TreeNode, error) { } // Get all stack branches to find children - stackBranches, err := GetStackBranches() + stackBranches, err := GetStackBranches(gitClient) if err != nil { return nil, err } @@ -227,7 +227,7 @@ func BuildStackTreeForBranch(branchName string) (*TreeNode, error) { // If the root is not the base branch, we need to include the base branch // in the tree as the actual root - baseBranch := GetBaseBranch() + baseBranch := GetBaseBranch(gitClient) if root != baseBranch { // Check if the root has a parent in childrenMap (meaning there are branches // that have root as their parent) diff --git a/internal/stack/stack_test.go b/internal/stack/stack_test.go new file mode 100644 index 0000000..a535366 --- /dev/null +++ b/internal/stack/stack_test.go @@ -0,0 +1,333 @@ +package stack + +import ( + "testing" + + "github.com/javoire/stackinator/internal/testutil" + "github.com/stretchr/testify/assert" +) + +func TestGetStackBranches(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + stackParents map[string]string + expectedBranches []string + expectError bool + }{ + { + name: "simple stack", + stackParents: map[string]string{ + "feature-a": "main", + "feature-b": "feature-a", + }, + expectedBranches: []string{"feature-a", "feature-b"}, + expectError: false, + }, + { + name: "no stack branches", + stackParents: map[string]string{}, + expectedBranches: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGit.On("GetAllStackParents").Return(tt.stackParents, nil) + + branches, err := GetStackBranches(mockGit) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, branches, len(tt.expectedBranches)) + + branchNames := make(map[string]bool) + for _, b := range branches { + branchNames[b.Name] = true + } + + for _, expectedName := range tt.expectedBranches { + assert.True(t, branchNames[expectedName], "Expected branch %s not found", expectedName) + } + } + + mockGit.AssertExpectations(t) + }) + } +} + +func TestGetStackChain(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + branch string + stackParents map[string]string + expectedChain []string + expectError bool + }{ + { + name: "simple chain", + branch: "feature-c", + stackParents: map[string]string{ + "feature-a": "main", + "feature-b": "feature-a", + "feature-c": "feature-b", + }, + expectedChain: []string{"main", "feature-a", "feature-b", "feature-c"}, // Includes base + expectError: false, + }, + { + name: "single branch", + branch: "feature-a", + stackParents: map[string]string{ + "feature-a": "main", + }, + expectedChain: []string{"main", "feature-a"}, // Includes base + expectError: false, + }, + { + name: "circular dependency", + branch: "feature-b", + stackParents: map[string]string{ + "feature-a": "feature-b", + "feature-b": "feature-a", + }, + expectedChain: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGit.On("GetAllStackParents").Return(tt.stackParents, nil) + + chain, err := GetStackChain(mockGit, tt.branch) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedChain, chain) + } + + mockGit.AssertExpectations(t) + }) + } +} + +func TestGetBaseBranch(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + configuredBase string + defaultBranch string + expectedBase string + }{ + { + name: "configured base branch", + configuredBase: "develop", + defaultBranch: "main", + expectedBase: "develop", + }, + { + name: "no configured base, use default", + configuredBase: "", + defaultBranch: "main", + expectedBase: "main", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGit.On("GetConfig", "stack.baseBranch").Return(tt.configuredBase) + if tt.configuredBase == "" { + mockGit.On("GetDefaultBranch").Return(tt.defaultBranch) + } + + base := GetBaseBranch(mockGit) + + assert.Equal(t, tt.expectedBase, base) + mockGit.AssertExpectations(t) + }) + } +} + +func TestBuildStackTree(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + stackParents map[string]string + baseBranch string + expectedRootName string + expectError bool + }{ + { + name: "simple tree", + stackParents: map[string]string{ + "feature-a": "main", + "feature-b": "main", + "feature-c": "feature-a", + }, + baseBranch: "main", + expectedRootName: "main", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockGit := new(testutil.MockGitClient) + mockGit.On("GetAllStackParents").Return(tt.stackParents, nil) + mockGit.On("GetConfig", "stack.baseBranch").Return("") + mockGit.On("GetDefaultBranch").Return(tt.baseBranch) + + tree, err := BuildStackTree(mockGit) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, tree) + assert.Equal(t, tt.expectedRootName, tree.Name) + } + + mockGit.AssertExpectations(t) + }) + } +} + +func TestTopologicalSort(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + tests := []struct { + name string + branches []StackBranch + expectedOrder []string // Names of branches in expected order + expectError bool + }{ + { + name: "simple linear stack", + branches: []StackBranch{ + {Name: "feature-c", Parent: "feature-b"}, + {Name: "feature-a", Parent: "main"}, + {Name: "feature-b", Parent: "feature-a"}, + }, + expectedOrder: []string{"feature-a", "feature-b", "feature-c"}, + expectError: false, + }, + { + name: "branches with shared parent", + branches: []StackBranch{ + {Name: "feature-a", Parent: "main"}, + {Name: "feature-b", Parent: "main"}, + }, + expectedOrder: []string{"feature-a", "feature-b"}, // Alphabetical within same level + expectError: false, + }, + { + name: "circular dependency", + branches: []StackBranch{ + {Name: "feature-a", Parent: "feature-b"}, + {Name: "feature-b", Parent: "feature-a"}, + }, + expectedOrder: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sorted, err := TopologicalSort(tt.branches) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, sorted, len(tt.expectedOrder)) + + // Check order + for i, expectedName := range tt.expectedOrder { + assert.Equal(t, expectedName, sorted[i].Name, "Branch at position %d should be %s", i, expectedName) + } + } + }) + } +} + +func TestBuildStackTreeForBranch(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + mockGit := new(testutil.MockGitClient) + stackParents := map[string]string{ + "feature-a": "main", + "feature-b": "feature-a", + "feature-c": "feature-b", + "other-branch": "main", + } + + // Mock all the calls + mockGit.On("GetAllStackParents").Return(stackParents, nil).Times(2) // Called twice in the function + mockGit.On("GetConfig", "stack.baseBranch").Return("") + mockGit.On("GetDefaultBranch").Return("main") + + // Build tree for feature-c - should only include its chain + tree, err := BuildStackTreeForBranch(mockGit, "feature-c") + + assert.NoError(t, err) + assert.NotNil(t, tree) + assert.Equal(t, "main", tree.Name) + + // Verify the tree structure: main -> feature-a -> feature-b -> feature-c + // but NOT other-branch + assert.Len(t, tree.Children, 1) + assert.Equal(t, "feature-a", tree.Children[0].Name) + assert.Len(t, tree.Children[0].Children, 1) + assert.Equal(t, "feature-b", tree.Children[0].Children[0].Name) + assert.Len(t, tree.Children[0].Children[0].Children, 1) + assert.Equal(t, "feature-c", tree.Children[0].Children[0].Children[0].Name) + + mockGit.AssertExpectations(t) +} + +func TestGetChildrenOf(t *testing.T) { + testutil.SetupTest() + defer testutil.TeardownTest() + + mockGit := new(testutil.MockGitClient) + stackParents := map[string]string{ + "feature-a": "main", + "feature-b": "main", + "feature-c": "feature-a", + } + + mockGit.On("GetAllStackParents").Return(stackParents, nil) + + children, err := GetChildrenOf(mockGit, "main") + + assert.NoError(t, err) + assert.Len(t, children, 2) + + // Should be sorted alphabetically + names := []string{children[0].Name, children[1].Name} + assert.Contains(t, names, "feature-a") + assert.Contains(t, names, "feature-b") + + mockGit.AssertExpectations(t) +} + diff --git a/internal/testutil/fixtures.go b/internal/testutil/fixtures.go new file mode 100644 index 0000000..e5ffdbb --- /dev/null +++ b/internal/testutil/fixtures.go @@ -0,0 +1,40 @@ +package testutil + +import "github.com/javoire/stackinator/internal/github" + +// BuildStackParents creates a map of branch names to their stack parents for testing +func BuildStackParents(config map[string]string) map[string]string { + return config +} + +// CreatePRMap creates a map of branch names to PR info for testing +func CreatePRMap(prs map[string]*github.PRInfo) map[string]*github.PRInfo { + return prs +} + +// NewPRInfo creates a PR info struct for testing +func NewPRInfo(number int, state, base, title, url string) *github.PRInfo { + return &github.PRInfo{ + Number: number, + State: state, + Base: base, + Title: title, + URL: url, + MergeStateStatus: "CLEAN", + } +} + +// BuildGitConfig simulates git config output +func BuildGitConfig(configs map[string]string) map[string]string { + return configs +} + +// BuildRemoteBranchesSet creates a set of remote branches for testing +func BuildRemoteBranchesSet(branches []string) map[string]bool { + set := make(map[string]bool) + for _, branch := range branches { + set[branch] = true + } + return set +} + diff --git a/internal/testutil/mocks.go b/internal/testutil/mocks.go new file mode 100644 index 0000000..f24ea36 --- /dev/null +++ b/internal/testutil/mocks.go @@ -0,0 +1,225 @@ +package testutil + +import ( + "github.com/javoire/stackinator/internal/github" + "github.com/stretchr/testify/mock" +) + +// MockGitClient is a mock implementation of git.GitClient for testing +type MockGitClient struct { + mock.Mock +} + +func (m *MockGitClient) GetRepoRoot() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) +} + +func (m *MockGitClient) GetCurrentBranch() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) +} + +func (m *MockGitClient) ListBranches() ([]string, error) { + args := m.Called() + return args.Get(0).([]string), args.Error(1) +} + +func (m *MockGitClient) GetConfig(key string) string { + args := m.Called(key) + return args.String(0) +} + +func (m *MockGitClient) GetAllStackParents() (map[string]string, error) { + args := m.Called() + return args.Get(0).(map[string]string), args.Error(1) +} + +func (m *MockGitClient) SetConfig(key, value string) error { + args := m.Called(key, value) + return args.Error(0) +} + +func (m *MockGitClient) UnsetConfig(key string) error { + args := m.Called(key) + return args.Error(0) +} + +func (m *MockGitClient) CreateBranch(name, from string) error { + args := m.Called(name, from) + return args.Error(0) +} + +func (m *MockGitClient) CheckoutBranch(name string) error { + args := m.Called(name) + return args.Error(0) +} + +func (m *MockGitClient) RenameBranch(oldName, newName string) error { + args := m.Called(oldName, newName) + return args.Error(0) +} + +func (m *MockGitClient) Rebase(onto string) error { + args := m.Called(onto) + return args.Error(0) +} + +func (m *MockGitClient) RebaseOnto(newBase, oldBase, currentBranch string) error { + args := m.Called(newBase, oldBase, currentBranch) + return args.Error(0) +} + +func (m *MockGitClient) FetchBranch(branch string) error { + args := m.Called(branch) + return args.Error(0) +} + +func (m *MockGitClient) Push(branch string, forceWithLease bool) error { + args := m.Called(branch, forceWithLease) + return args.Error(0) +} + +func (m *MockGitClient) ForcePush(branch string) error { + args := m.Called(branch) + return args.Error(0) +} + +func (m *MockGitClient) IsWorkingTreeClean() (bool, error) { + args := m.Called() + return args.Bool(0), args.Error(1) +} + +func (m *MockGitClient) Fetch() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockGitClient) BranchExists(name string) bool { + args := m.Called(name) + return args.Bool(0) +} + +func (m *MockGitClient) RemoteBranchExists(name string) bool { + args := m.Called(name) + return args.Bool(0) +} + +func (m *MockGitClient) GetRemoteBranchesSet() map[string]bool { + args := m.Called() + return args.Get(0).(map[string]bool) +} + +func (m *MockGitClient) AbortRebase() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockGitClient) ResetToRemote(branch string) error { + args := m.Called(branch) + return args.Error(0) +} + +func (m *MockGitClient) GetMergeBase(branch1, branch2 string) (string, error) { + args := m.Called(branch1, branch2) + return args.String(0), args.Error(1) +} + +func (m *MockGitClient) GetCommitHash(ref string) (string, error) { + args := m.Called(ref) + return args.String(0), args.Error(1) +} + +func (m *MockGitClient) Stash(message string) error { + args := m.Called(message) + return args.Error(0) +} + +func (m *MockGitClient) StashPop() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockGitClient) GetDefaultBranch() string { + args := m.Called() + return args.String(0) +} + +func (m *MockGitClient) GetWorktreeBranches() (map[string]string, error) { + args := m.Called() + return args.Get(0).(map[string]string), args.Error(1) +} + +func (m *MockGitClient) GetCurrentWorktreePath() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) +} + +func (m *MockGitClient) IsCommitsBehind(branch, base string) (bool, error) { + args := m.Called(branch, base) + return args.Bool(0), args.Error(1) +} + +func (m *MockGitClient) DeleteBranch(name string) error { + args := m.Called(name) + return args.Error(0) +} + +func (m *MockGitClient) DeleteBranchForce(name string) error { + args := m.Called(name) + return args.Error(0) +} + +func (m *MockGitClient) AddWorktree(path, branch string) error { + args := m.Called(path, branch) + return args.Error(0) +} + +func (m *MockGitClient) AddWorktreeNewBranch(path, newBranch, baseBranch string) error { + args := m.Called(path, newBranch, baseBranch) + return args.Error(0) +} + +func (m *MockGitClient) AddWorktreeFromRemote(path, branch string) error { + args := m.Called(path, branch) + return args.Error(0) +} + +func (m *MockGitClient) RemoveWorktree(path string) error { + args := m.Called(path) + return args.Error(0) +} + +func (m *MockGitClient) ListWorktrees() ([]string, error) { + args := m.Called() + return args.Get(0).([]string), args.Error(1) +} + +// MockGitHubClient is a mock implementation of github.GitHubClient for testing +type MockGitHubClient struct { + mock.Mock +} + +func (m *MockGitHubClient) GetPRForBranch(branch string) (*github.PRInfo, error) { + args := m.Called(branch) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*github.PRInfo), args.Error(1) +} + +func (m *MockGitHubClient) GetAllPRs() (map[string]*github.PRInfo, error) { + args := m.Called() + return args.Get(0).(map[string]*github.PRInfo), args.Error(1) +} + +func (m *MockGitHubClient) UpdatePRBase(prNumber int, newBase string) error { + args := m.Called(prNumber, newBase) + return args.Error(0) +} + +func (m *MockGitHubClient) IsPRMerged(prNumber int) (bool, error) { + args := m.Called(prNumber) + return args.Bool(0), args.Error(1) +} + diff --git a/internal/testutil/setup.go b/internal/testutil/setup.go new file mode 100644 index 0000000..574305a --- /dev/null +++ b/internal/testutil/setup.go @@ -0,0 +1,16 @@ +package testutil + +import ( + "github.com/javoire/stackinator/internal/spinner" +) + +// SetupTest initializes test environment (disable spinners, etc.) +func SetupTest() { + spinner.Enabled = false +} + +// TeardownTest cleans up after tests +func TeardownTest() { + // Currently no cleanup needed, but keeping for future use +} +