Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/patch-fix-ghes-pr-creation.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 54 additions & 0 deletions pkg/cli/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,60 @@ func parseGitHubRepoSlugFromURL(url string) string {
return ""
}

// extractHostFromRemoteURL extracts the host (optionally including port) from a git remote URL.
// Supports HTTPS (https://host[:port]/path), HTTP (http://host[:port]/path), and SSH (git@host[:port]:path or ssh://git@host[:port]/path) formats.
// Returns the host portion as "host[:port]" when parsed, or "github.com" as the default if the URL cannot be parsed.
func extractHostFromRemoteURL(remoteURL string) string {
// HTTPS / HTTP format: https://host/path or http://host/path
for _, scheme := range []string{"https://", "http://"} {
if after, ok := strings.CutPrefix(remoteURL, scheme); ok {
if host, _, found := strings.Cut(after, "/"); found {
return host
}
return after
}
}

// SSH scp-like format: git@host:path
if after, ok := strings.CutPrefix(remoteURL, "git@"); ok {
if host, _, found := strings.Cut(after, ":"); found {
return host
}
}

// SSH URL format: ssh://git@host/path or ssh://host/path
if after, ok := strings.CutPrefix(remoteURL, "ssh://"); ok {
// Strip optional user info (e.g. "git@")
if _, userStripped, hasAt := strings.Cut(after, "@"); hasAt {
after = userStripped
}
if host, _, found := strings.Cut(after, "/"); found {
return host
}
return after
}

return "github.com"
}

// getHostFromOriginRemote returns the hostname of the git origin remote.
// For example, a remote URL of "https://ghes.example.com/org/repo.git" returns "ghes.example.com",
// and "git@github.com:owner/repo.git" returns "github.com".
// Returns "github.com" as the default if the remote URL cannot be determined.
func getHostFromOriginRemote() string {
cmd := exec.Command("git", "config", "--get", "remote.origin.url")
output, err := cmd.Output()
if err != nil {
gitLog.Printf("Failed to get remote origin URL: %v", err)
return "github.com"
}

remoteURL := strings.TrimSpace(string(output))
host := extractHostFromRemoteURL(remoteURL)
gitLog.Printf("Detected GitHub host from remote origin: %s", host)
return host
}

// getRepositorySlugFromRemote extracts the repository slug (owner/repo) from git remote URL
func getRepositorySlugFromRemote() string {
gitLog.Print("Getting repository slug from git remote")
Expand Down
127 changes: 127 additions & 0 deletions pkg/cli/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,130 @@ func TestCheckWorkflowFileStatusNotInRepo(t *testing.T) {
t.Error("Expected empty status when not in git repository")
}
}

func TestExtractHostFromRemoteURL(t *testing.T) {
tests := []struct {
name string
url string
expected string
}{
{
name: "public GitHub HTTPS",
url: "https://github.com/owner/repo.git",
expected: "github.com",
},
{
name: "public GitHub SSH scp-like",
url: "git@github.com:owner/repo.git",
expected: "github.com",
},
{
name: "GHES HTTPS",
url: "https://ghes.example.com/org/repo.git",
expected: "ghes.example.com",
},
{
name: "GHES SSH scp-like",
url: "git@ghes.example.com:org/repo.git",
expected: "ghes.example.com",
},
{
name: "GHES HTTPS without .git suffix",
url: "https://ghes.example.com/org/repo",
expected: "ghes.example.com",
},
{
name: "SSH URL format with user",
url: "ssh://git@ghes.example.com/org/repo.git",
expected: "ghes.example.com",
},
{
name: "SSH URL format without user",
url: "ssh://ghes.example.com/org/repo.git",
expected: "ghes.example.com",
},
{
name: "HTTP URL",
url: "http://ghes.example.com/org/repo.git",
expected: "ghes.example.com",
},
{
name: "empty URL defaults to github.com",
url: "",
expected: "github.com",
},
{
name: "unrecognized URL defaults to github.com",
url: "not-a-url",
expected: "github.com",
},
{
name: "GHES with port",
url: "https://ghes.example.com:8443/org/repo.git",
expected: "ghes.example.com:8443",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractHostFromRemoteURL(tt.url)
if got != tt.expected {
t.Errorf("extractHostFromRemoteURL(%q) = %q, want %q", tt.url, got, tt.expected)
}
})
}
}

func TestGetHostFromOriginRemote(t *testing.T) {
tmpDir := testutil.TempDir(t, "test-get-host-*")

originalDir, err := os.Getwd()
if err != nil {
t.Fatalf("Failed to get current directory: %v", err)
}
defer func() {
if err := os.Chdir(originalDir); err != nil {
t.Logf("Warning: failed to restore directory: %v", err)
}
}()

if err := os.Chdir(tmpDir); err != nil {
t.Fatalf("Failed to change to temp directory: %v", err)
}

// Initialize a git repo
if err := exec.Command("git", "init").Run(); err != nil {
t.Skip("Git not available")
}

t.Run("no remote defaults to github.com", func(t *testing.T) {
got := getHostFromOriginRemote()
if got != "github.com" {
t.Errorf("getHostFromOriginRemote() without remote = %q, want %q", got, "github.com")
}
})

t.Run("public GitHub remote", func(t *testing.T) {
if err := exec.Command("git", "remote", "add", "origin", "https://github.com/owner/repo.git").Run(); err != nil {
t.Fatalf("Failed to add remote: %v", err)
}
defer func() { _ = exec.Command("git", "remote", "remove", "origin").Run() }()

got := getHostFromOriginRemote()
if got != "github.com" {
t.Errorf("getHostFromOriginRemote() = %q, want %q", got, "github.com")
}
})

t.Run("GHES remote", func(t *testing.T) {
if err := exec.Command("git", "remote", "add", "origin", "https://ghes.example.com/org/repo.git").Run(); err != nil {
t.Fatalf("Failed to add remote: %v", err)
}
defer func() { _ = exec.Command("git", "remote", "remove", "origin").Run() }()

got := getHostFromOriginRemote()
if got != "ghes.example.com" {
t.Errorf("getHostFromOriginRemote() = %q, want %q", got, "ghes.example.com")
}
})
}
23 changes: 20 additions & 3 deletions pkg/cli/pr_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,8 +767,18 @@ func createPR(branchName, title, body string, verbose bool) (int, string, error)
fmt.Fprintln(os.Stderr, console.FormatProgressMessage("Creating PR: "+title))
}

// Detect the GitHub host from the git remote so that GitHub Enterprise Server
// repositories are targeted correctly instead of defaulting to github.com.
remoteHost := getHostFromOriginRemote()

// Build gh repo view args, adding --hostname for GHES instances.
repoViewArgs := []string{"repo", "view", "--json", "owner,name"}
if remoteHost != "github.com" {
repoViewArgs = append(repoViewArgs, "--hostname", remoteHost)
}

// Get the current repository info to ensure PR is created in the correct repo
repoOutput, err := workflow.RunGH("Fetching repository info...", "repo", "view", "--json", "owner,name")
repoOutput, err := workflow.RunGH("Fetching repository info...", repoViewArgs...)
if err != nil {
return 0, "", fmt.Errorf("failed to get current repository info: %w", err)
}
Expand All @@ -786,8 +796,15 @@ func createPR(branchName, title, body string, verbose bool) (int, string, error)

repoSpec := fmt.Sprintf("%s/%s", repoInfo.Owner.Login, repoInfo.Name)

// Explicitly specify the repository to ensure PR is created in the current repo (not upstream)
output, err := workflow.RunGH("Creating pull request...", "pr", "create", "--repo", repoSpec, "--title", title, "--body", body, "--head", branchName)
// Build gh pr create args. Explicitly specifying --repo ensures the PR is created in the
// current repo (not an upstream fork). For GHES instances, --hostname routes the request
// to the correct GitHub Enterprise host instead of defaulting to github.com.
prCreateArgs := []string{"pr", "create", "--repo", repoSpec, "--title", title, "--body", body, "--head", branchName}
if remoteHost != "github.com" {
prCreateArgs = append(prCreateArgs, "--hostname", remoteHost)
}

output, err := workflow.RunGH("Creating pull request...", prCreateArgs...)
if err != nil {
// Try to get stderr for better error reporting
var exitError *exec.ExitError
Expand Down
Loading