Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 82 additions & 9 deletions webhook_server/libs/handlers/check_run_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any

from github.CheckRun import CheckRun
from github.CommitStatus import CommitStatus
from github.PullRequest import PullRequest
from github.Repository import Repository

Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, github_webhook: "GithubWebhook", owners_file_handler: OwnersF
self.repository: Repository = self.github_webhook.repository
self._repository_private: bool | None = None
self._branch_required_status_checks: list[str] | None = None
self._all_required_status_checks: list[str] | None = None
if isinstance(self.owners_file_handler, OwnersFileHandler):
self.labels_handler = LabelsHandler(
github_webhook=self.github_webhook, owners_file_handler=self.owners_file_handler
Expand Down Expand Up @@ -360,26 +362,84 @@ async def is_check_run_in_progress(self, check_run: str) -> bool:
return False

async def required_check_failed_or_no_status(
self, pull_request: PullRequest, last_commit_check_runs: list[CheckRun], check_runs_in_progress: list[str]
self,
pull_request: PullRequest,
last_commit_check_runs: list[CheckRun],
last_commit_statuses: list[CommitStatus],
check_runs_in_progress: list[str],
) -> str:
failed_check_runs: list[str] = []
no_status_check_runs: list[str] = []

# Find required checks that are missing entirely from check runs list
required_checks = set(await self.all_required_status_checks(pull_request=pull_request))
existing_check_names = {cr.name for cr in last_commit_check_runs}
missing_required_checks = required_checks - existing_check_names

# Add missing checks to no_status list (these haven't been created yet)
no_status_check_runs: list[str] = list(missing_required_checks)

# Add commit statuses (legacy API) to existing checks
status_check_names = {status.context for status in last_commit_statuses}
existing_check_names = existing_check_names | status_check_names

# Recalculate missing checks after adding statuses
missing_required_checks = required_checks - existing_check_names
no_status_check_runs = list(missing_required_checks)

# Check commit statuses for failures/pending
self.logger.debug(f"{self.log_prefix} Status details: {[(s.context, s.state) for s in last_commit_statuses]}")

# Filter to latest status per context (highest ID = most recent)
status_by_context: dict[str, CommitStatus] = {}
for status in last_commit_statuses:
if status.context not in status_by_context or status.id > status_by_context[status.context].id:
status_by_context[status.context] = status

latest_statuses = list(status_by_context.values())
self.logger.debug(
f"{self.log_prefix} Filtered {len(last_commit_statuses)} statuses to {len(latest_statuses)} latest statuses"
)

for status in latest_statuses:
if status.context not in required_checks:
continue # Not a required check

if status.state == "success":
continue # Passed

if status.state == "pending":
# Skip if already marked as in-progress (to avoid duplicate reporting)
if status.context in check_runs_in_progress:
continue
if status.context not in no_status_check_runs:
no_status_check_runs.append(status.context)
elif status.state in ("failure", "error"):
if status.context not in failed_check_runs:
failed_check_runs.append(status.context)

for check_run in last_commit_check_runs:
self.logger.debug(f"{self.log_prefix} Check if {check_run.name} failed or do not have status.")
# Skip check runs that have a corresponding success status
status_contexts = {status.context for status in latest_statuses if status.state == "success"}
if check_run.name in status_contexts:
continue

if (
check_run.name == CAN_BE_MERGED_STR
or check_run.conclusion == SUCCESS_STR
or check_run.name not in await self.all_required_status_checks(pull_request=pull_request)
):
self.logger.debug(f"{self.log_prefix} {check_run.name} is success or not required, skipping.")
continue

if check_run.conclusion is None:
no_status_check_runs.append(check_run.name)
if check_run.name not in no_status_check_runs:
no_status_check_runs.append(check_run.name)

else:
failed_check_runs.append(check_run.name)
if check_run.name not in failed_check_runs:
failed_check_runs.append(check_run.name)

self.logger.debug(f"{self.log_prefix} no_status_check_runs after processing check runs: {no_status_check_runs}")
self.logger.debug(f"{self.log_prefix} failed_check_runs after processing check runs: {failed_check_runs}")

msg = ""

Expand All @@ -399,6 +459,10 @@ async def required_check_failed_or_no_status(
return msg

async def all_required_status_checks(self, pull_request: PullRequest) -> list[str]:
# Cache to avoid repeated processing
if self._all_required_status_checks is not None:
return self._all_required_status_checks

all_required_status_checks: list[str] = []
branch_required_status_checks = await self.get_branch_required_status_checks(pull_request=pull_request)

Expand All @@ -419,12 +483,13 @@ async def all_required_status_checks(self, pull_request: PullRequest) -> list[st

_all_required_status_checks = branch_required_status_checks + all_required_status_checks
self.logger.debug(f"{self.log_prefix} All required status checks: {_all_required_status_checks}")
self._all_required_status_checks = _all_required_status_checks
return _all_required_status_checks

async def get_branch_required_status_checks(self, pull_request: PullRequest) -> list[str]:
# Check if private repo first (cache to avoid repeated API calls)
if self._repository_private is None:
self._repository_private = self.repository.private
self._repository_private = await asyncio.to_thread(lambda: self.repository.private)

if self._repository_private:
self.logger.info(
Expand All @@ -438,13 +503,17 @@ async def get_branch_required_status_checks(self, pull_request: PullRequest) ->

pull_request_branch = await asyncio.to_thread(self.repository.get_branch, pull_request.base.ref)
branch_protection = await asyncio.to_thread(pull_request_branch.get_protection)
branch_required_status_checks = branch_protection.required_status_checks.contexts
branch_required_status_checks = await asyncio.to_thread(
lambda: branch_protection.required_status_checks.contexts
)
self.logger.debug(f"{self.log_prefix} branch_required_status_checks: {branch_required_status_checks}")
self._branch_required_status_checks = branch_required_status_checks
return self._branch_required_status_checks

async def required_check_in_progress(
self, pull_request: PullRequest, last_commit_check_runs: list[CheckRun]
self,
pull_request: PullRequest,
last_commit_check_runs: list[CheckRun],
) -> tuple[str, list[str]]:
self.logger.debug(f"{self.log_prefix} Check if any required check runs in progress.")

Expand All @@ -455,6 +524,10 @@ async def required_check_in_progress(
and check_run.name != CAN_BE_MERGED_STR
and check_run.name in await self.all_required_status_checks(pull_request=pull_request)
]

# Note: Status API doesn't have an "in_progress" state - only pending (queued),
# success, failure, and error. We only check Check Runs for in-progress status.

if check_runs_in_progress:
self.logger.debug(
f"{self.log_prefix} Some required check runs in progress {check_runs_in_progress}, "
Expand Down
20 changes: 17 additions & 3 deletions webhook_server/libs/handlers/pull_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,8 +956,20 @@ async def check_if_can_be_merged(self, pull_request: PullRequest) -> None:
try:
self.logger.info(f"{self.log_prefix} Check if {CAN_BE_MERGED_STR}.")
await self.check_run_handler.set_merge_check_in_progress()
_last_commit_check_runs = await asyncio.to_thread(self.github_webhook.last_commit.get_check_runs)
last_commit_check_runs = list(_last_commit_check_runs)
# Fetch check runs and statuses in parallel (2 API calls → 1 concurrent operation)
_check_runs, _statuses = await asyncio.gather(
asyncio.to_thread(lambda: list(self.github_webhook.last_commit.get_check_runs())),
asyncio.to_thread(lambda: list(self.github_webhook.last_commit.get_statuses())),
)
last_commit_check_runs = _check_runs
last_commit_statuses = _statuses
self.logger.debug(
f"{self.log_prefix} Fetched {len(last_commit_check_runs)} check runs "
f"and {len(last_commit_statuses)} statuses"
)
if last_commit_statuses:
status_names = [s.context for s in last_commit_statuses]
self.logger.debug(f"{self.log_prefix} Commit statuses: {status_names}")
_labels = await self.labels_handler.pull_request_labels_names(pull_request=pull_request)
self.logger.debug(f"{self.log_prefix} check if can be merged. PR labels are: {_labels}")

Expand All @@ -970,7 +982,8 @@ async def check_if_can_be_merged(self, pull_request: PullRequest) -> None:
required_check_in_progress_failure_output,
check_runs_in_progress,
) = await self.check_run_handler.required_check_in_progress(
pull_request=pull_request, last_commit_check_runs=last_commit_check_runs
pull_request=pull_request,
last_commit_check_runs=last_commit_check_runs,
)
if required_check_in_progress_failure_output:
failure_output += required_check_in_progress_failure_output
Expand All @@ -984,6 +997,7 @@ async def check_if_can_be_merged(self, pull_request: PullRequest) -> None:
required_check_failed_failure_output = await self.check_run_handler.required_check_failed_or_no_status(
pull_request=pull_request,
last_commit_check_runs=last_commit_check_runs,
last_commit_statuses=last_commit_statuses,
check_runs_in_progress=check_runs_in_progress,
)
if required_check_failed_failure_output:
Expand Down
Loading