Skip to content
135 changes: 123 additions & 12 deletions webhook_server/libs/handlers/pull_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,21 +732,132 @@ async def remove_labels_when_pull_request_sync(self, pull_request: PullRequest)
if isinstance(result, Exception):
self.logger.error(f"{self.log_prefix} Async task failed: {result}")

async def _compare_branches(self, base_ref: str, head_ref_full: str) -> dict[str, Any] | None:
"""Call GitHub Compare API to get branch comparison data for rebase detection.

This API is used ONLY for detecting if a PR is behind/diverged from base branch.
It does NOT provide conflict information - use pull_request.mergeable for conflicts.

Args:
base_ref: Base branch reference (e.g., "main").
head_ref_full: Full head reference including owner (e.g., "user:branch").

Returns:
Compare API response data or None if API call fails.

Compare API Reference:
GET /repos/{owner}/{repo}/compare/{base}...{head}
Response fields used:
- behind_by: int - commits behind base branch
- status: str - "ahead", "behind", "diverged", "identical"

NOTE: This API does NOT return conflict information (mergeable/mergeable_state).
"""
try:
_headers, data = await asyncio.to_thread(
self.repository._requester.requestJsonAndCheck,
"GET",
f"{self.repository.url}/compare/{base_ref}...{head_ref_full}",
)
return data
except GithubException:
self.logger.exception(f"{self.log_prefix} Failed to call Compare API for {base_ref}...{head_ref_full}")
return None
except Exception:
self.logger.exception(f"{self.log_prefix} Unexpected error calling Compare API")
return None

async def label_pull_request_by_merge_state(self, pull_request: PullRequest) -> None:
    """Sync the has-conflicts and needs-rebase labels with the pull request's merge state.

    Flow:
        1. Read ``pull_request.mergeable``.
        2. ``False`` (merge conflict): make sure the has-conflicts label is set
           and return. The needs-rebase label is left untouched on this path,
           so both labels can coexist.
        3. ``None``: mergeability is not known yet (GitHub computes it
           asynchronously); the has-conflicts label is left as-is rather than
           being removed on incomplete information.
        4. ``True``: remove the has-conflicts label if present.
        5. For cases 3 and 4, query the Compare API and add/remove the
           needs-rebase label: needed when ``behind_by > 0`` or
           ``status == "diverged"``. If the Compare API call fails, the
           needs-rebase label is left unchanged.

    Args:
        pull_request: The GitHub pull request object to label.

    Raises:
        asyncio.CancelledError: Re-raised unchanged when the task is cancelled.
        Exception: Any other failure is logged, reported to ``self.ctx`` (when
            set) via ``fail_step`` and re-raised.
    """
    if self.ctx:
        self.ctx.start_step("label_merge_state")

    try:
        # Fetch the current labels once so every add/remove decision below is
        # made without further label lookups.
        current_labels = await self.labels_handler.pull_request_labels_names(pull_request=pull_request)
        has_conflicts_label_exists = HAS_CONFLICTS_LABEL_STR in current_labels
        needs_rebase_label_exists = NEEDS_REBASE_LABEL_STR in current_labels

        # Step 1: Check for conflicts first. Attribute access may trigger a
        # blocking API call, hence the worker thread.
        mergeable = await asyncio.to_thread(lambda: pull_request.mergeable)
        has_conflicts = mergeable is False

        if has_conflicts:
            # Has conflicts - add has-conflicts label and exit
            self.logger.debug(f"{self.log_prefix} PR has conflicts. {mergeable=}")

            if not has_conflicts_label_exists:
                self.logger.debug(f"{self.log_prefix} Adding {HAS_CONFLICTS_LABEL_STR} label")
                await self.labels_handler._add_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR)

            if self.ctx:
                self.ctx.complete_step("label_merge_state", has_conflicts=True, needs_rebase=False)
            return  # Exit early - conflicts take precedence

        # Step 2: Only a definite "mergeable" answer may clear the label.
        if mergeable is None:
            # Unknown is not the same as conflict-free: keep whatever label
            # state exists instead of dropping has-conflicts prematurely.
            self.logger.debug(
                f"{self.log_prefix} Mergeable state not computed yet, leaving {HAS_CONFLICTS_LABEL_STR} label unchanged"
            )
        elif has_conflicts_label_exists:
            self.logger.debug(f"{self.log_prefix} Removing {HAS_CONFLICTS_LABEL_STR} label")
            await self.labels_handler._remove_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR)

        # Step 3: Check if needs rebase via Compare API
        base_ref, head_user_login, head_ref = await asyncio.gather(
            asyncio.to_thread(lambda: pull_request.base.ref),
            asyncio.to_thread(lambda: pull_request.head.user.login),
            asyncio.to_thread(lambda: pull_request.head.ref),
        )
        # "owner:branch" form so PRs opened from forks are compared correctly.
        head_ref_full = f"{head_user_login}:{head_ref}"

        compare_data = await self._compare_branches(base_ref=base_ref, head_ref_full=head_ref_full)
        if compare_data is None:
            self.logger.warning(f"{self.log_prefix} Compare API failed, skipping rebase label update")
            if self.ctx:
                self.ctx.complete_step("label_merge_state", compare_api_failed=True)
            return

        behind_by = compare_data.get("behind_by", 0)
        status = compare_data.get("status", "")

        needs_rebase = behind_by > 0 or status == "diverged"

        self.logger.debug(
            f"{self.log_prefix} Compare API - behind_by: {behind_by}, "
            f"status: {status}, needs_rebase: {needs_rebase}"
        )

        # Step 4: Update needs-rebase label (only when the label state must change)
        if needs_rebase and not needs_rebase_label_exists:
            self.logger.debug(f"{self.log_prefix} Adding {NEEDS_REBASE_LABEL_STR} label")
            await self.labels_handler._add_label(pull_request=pull_request, label=NEEDS_REBASE_LABEL_STR)
        elif not needs_rebase and needs_rebase_label_exists:
            self.logger.debug(f"{self.log_prefix} Removing {NEEDS_REBASE_LABEL_STR} label")
            await self.labels_handler._remove_label(pull_request=pull_request, label=NEEDS_REBASE_LABEL_STR)

        if self.ctx:
            self.ctx.complete_step("label_merge_state", has_conflicts=False, needs_rebase=needs_rebase)

    except asyncio.CancelledError:
        # Cancellation is not a failure: log it and let it propagate untouched.
        self.logger.debug(f"{self.log_prefix} Label merge state check cancelled")
        raise
    except Exception as ex:
        self.logger.exception(f"{self.log_prefix} Failed to label merge state")
        if self.ctx:
            self.ctx.fail_step("label_merge_state", ex, traceback.format_exc())
        raise

async def _process_verified_for_update_or_new_pull_request(self, pull_request: PullRequest) -> None:
if not self.github_webhook.verified_job:
Expand Down
Loading