Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/git/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ pub mod pull_request;
pub use clone::{clone_repository, remove_repository};
pub use common::Logger;
pub use pull_request::{
add_all_changes, commit_changes, create_and_checkout_branch, get_default_branch, has_changes,
push_branch,
add_all_changes, checkout_branch, commit_changes, create_and_checkout_branch,
get_current_branch, get_default_branch, has_changes, push_branch,
};
50 changes: 48 additions & 2 deletions src/git/pull_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,13 @@ pub fn push_branch(repo_path: &str, branch_name: &str) -> Result<()> {
.context("Failed to execute git push command")?;

if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let stdout = String::from_utf8_lossy(&output.stdout);
anyhow::bail!(
"Failed to push branch: {}",
String::from_utf8_lossy(&output.stderr)
"Failed to push branch '{}' to remote 'origin':\nstderr: {}\nstdout: {}",
branch_name,
stderr.trim(),
stdout.trim()
);
}

Expand Down Expand Up @@ -159,3 +163,45 @@ pub fn get_default_branch(repo_path: &str) -> Result<String> {
// Final fallback to default branch
Ok(crate::constants::git::FALLBACK_BRANCH.to_string())
}

/// Get the current branch name
pub fn get_current_branch(repo_path: &str) -> Result<String> {
let output = Command::new("git")
.args(["branch", "--show-current"])
.current_dir(repo_path)
.output()
.context("Failed to execute git branch command")?;

if !output.status.success() {
anyhow::bail!(
"Failed to get current branch: {}",
String::from_utf8_lossy(&output.stderr)
);
}

let branch = String::from_utf8_lossy(&output.stdout).trim().to_string();
if branch.is_empty() {
anyhow::bail!("No current branch (detached HEAD state?)");
}

Ok(branch)
}

/// Checkout an existing branch
pub fn checkout_branch(repo_path: &str, branch_name: &str) -> Result<()> {
let output = Command::new("git")
.args(["checkout", branch_name])
.current_dir(repo_path)
.output()
.context("Failed to execute git checkout command")?;

if !output.status.success() {
anyhow::bail!(
"Failed to checkout branch '{}': {}",
branch_name,
String::from_utf8_lossy(&output.stderr)
);
}

Ok(())
}
133 changes: 128 additions & 5 deletions src/github/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@ use anyhow::Result;
use colored::*;
use uuid::Uuid;

/// RAII guard to automatically restore the original branch on drop
struct BranchGuard<'a> {
repo_path: String,
original_branch: Option<String>,
repo_name: &'a str,
}

impl Drop for BranchGuard<'_> {
fn drop(&mut self) {
if let Some(ref original) = self.original_branch
&& let Err(e) = git::checkout_branch(&self.repo_path, original)
{
eprintln!(
"{} | {}",
self.repo_name.cyan().bold(),
format!(
"Warning: Failed to restore original branch '{}': {}",
original, e
)
.yellow()
);
}
}
}

/// High-level function to create a PR from local changes
///
/// This function encapsulates the entire pull request creation flow:
Expand All @@ -27,6 +52,14 @@ pub async fn create_pr_from_workspace(repo: &Repository, options: &PrOptions) ->
return Ok(());
}

// Save the current branch to restore later using RAII guard
let original_branch = git::get_current_branch(&repo_path).ok();
let _branch_guard = BranchGuard {
repo_path: repo_path.clone(),
original_branch: original_branch.clone(),
repo_name: &repo.name,
};

// Generate branch name if not provided
let branch_name = options.branch_name.clone().unwrap_or_else(|| {
format!(
Expand Down Expand Up @@ -61,6 +94,12 @@ pub async fn create_pr_from_workspace(repo: &Repository, options: &PrOptions) ->
"Pull request created:".green(),
pr_url
);
} else {
println!(
"{} | {}",
repo.name.cyan().bold(),
"Branch created (not pushed, --create-only mode)".yellow()
);
}

Ok(())
Expand Down Expand Up @@ -99,16 +138,44 @@ async fn create_github_pr(
}

/// Parse a GitHub URL to extract owner and repository name
///
/// Supports both SSH (git@host:owner/repo) and HTTPS (https://host/owner/repo) formats.
/// Works with GitHub, GitLab, Bitbucket, and other Git hosting providers.
fn parse_github_url(url: &str) -> Result<(String, String)> {
let url = url.trim_end_matches('/').trim_end_matches(".git");

let parts: Vec<&str> = url.split('/').collect();
if parts.len() < 2 {
anyhow::bail!("Invalid GitHub URL format: {url}");
// Handle SSH format: git@host:owner/repo or user@host:owner/repo
// The key indicator is the presence of '@' followed by ':' without '//'
if let Some(at_pos) = url.find('@')
&& let Some(colon_pos) = url[at_pos..].find(':')
{
// Extract the path after the colon
let path_start = at_pos + colon_pos + 1;
let path = &url[path_start..];

// Split owner/repo - use rsplit to handle nested paths like owner/group/repo
let mut parts = path.rsplitn(2, '/');
let repo_name = parts.next().ok_or_else(|| {
anyhow::anyhow!("Invalid SSH URL format: missing repo name in {}", url)
})?;
let owner = parts
.next()
.ok_or_else(|| anyhow::anyhow!("Invalid SSH URL format: missing owner in {}", url))?;

return Ok((owner.to_string(), repo_name.to_string()));
}

let repo_name = parts[parts.len() - 1];
let owner = parts[parts.len() - 2];
// Handle HTTPS format: https://host/owner/repo
// Use rsplit to efficiently get the last two segments
let mut parts = url.rsplitn(3, '/');
let repo_name = parts
.next()
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("Invalid URL format: missing repo name in {}", url))?;
let owner = parts
.next()
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("Invalid URL format: missing owner in {}", url))?;

Ok((owner.to_string(), repo_name.to_string()))
}
Expand Down Expand Up @@ -326,4 +393,60 @@ mod tests {

assert_eq!(options_with_base.base_branch.unwrap(), "develop");
}

#[test]
fn test_parse_github_url_https() {
// Test HTTPS URL parsing
let (owner, repo) =
parse_github_url("https://github.com/example-org/example-repo").unwrap();
assert_eq!(owner, "example-org");
assert_eq!(repo, "example-repo");

// Test with .git suffix
let (owner, repo) = parse_github_url("https://github.com/test-org/test-repo.git").unwrap();
assert_eq!(owner, "test-org");
assert_eq!(repo, "test-repo");

// Test with trailing slash
let (owner, repo) = parse_github_url("https://github.com/owner/repo/").unwrap();
assert_eq!(owner, "owner");
assert_eq!(repo, "repo");
}

#[test]
fn test_parse_github_url_ssh() {
// Test SSH URL parsing
let (owner, repo) = parse_github_url("git@github.com:example-org/example-repo").unwrap();
assert_eq!(owner, "example-org");
assert_eq!(repo, "example-repo");

// Test with .git suffix
let (owner, repo) = parse_github_url("git@github.com:test-org/test-repo.git").unwrap();
assert_eq!(owner, "test-org");
assert_eq!(repo, "test-repo");

// Test GitLab SSH format
let (owner, repo) = parse_github_url("git@gitlab.com:mycompany/myrepo").unwrap();
assert_eq!(owner, "mycompany");
assert_eq!(repo, "myrepo");

// Test Bitbucket SSH format
let (owner, repo) = parse_github_url("git@bitbucket.org:workspace/repository.git").unwrap();
assert_eq!(owner, "workspace");
assert_eq!(repo, "repository");
}

#[test]
fn test_parse_github_url_invalid() {
// Test truly invalid URLs - single words or malformed SSH
assert!(parse_github_url("invalid").is_err());
assert!(parse_github_url("git@github.com:").is_err());
assert!(parse_github_url("git@github.com:owner").is_err());

// Note: These cases actually succeed because they technically have owner/repo segments:
// - "https://github.com/" parses as owner="github.com", repo=""
// - "https://github.com/owner" parses as owner="github.com", repo="owner"
// These would fail at the API call level, not at URL parsing level
// To catch these, we'd need to validate against known hosts or check for empty strings
}
}
19 changes: 15 additions & 4 deletions tests/github_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,20 @@ async fn test_create_pr_workspace_commit_message_fallback() {
let result = create_pr_from_workspace(&repo, &options).await;
assert!(result.is_ok());

// Check that the commit was made with the title
// Get the created branch name (starts with "automated-changes-")
let output = std::process::Command::new("git")
.args(["log", "-1", "--pretty=format:%s"])
.args(["branch", "--list", "automated-changes-*"])
.current_dir(&repo_path)
.output()
.expect("git branch failed");

let branches = String::from_utf8(output.stdout).unwrap();
let branch_name = branches.trim().trim_start_matches("* ").trim();
assert!(branch_name.starts_with("automated-changes-"));

// Check that the commit was made with the title on the created branch
let output = std::process::Command::new("git")
.args(["log", "-1", "--pretty=format:%s", branch_name])
.current_dir(&repo_path)
.output()
.expect("git log failed");
Expand Down Expand Up @@ -355,9 +366,9 @@ async fn test_create_pr_workspace_custom_branch_and_commit() {
let branches = String::from_utf8(output.stdout).unwrap();
assert!(branches.contains("custom-branch"));

// Verify custom commit message was used
// Verify custom commit message was used on the custom-branch
let output = std::process::Command::new("git")
.args(["log", "-1", "--pretty=format:%s"])
.args(["log", "-1", "--pretty=format:%s", "custom-branch"])
.current_dir(&repo_path)
.output()
.expect("git log failed");
Expand Down