From 94f2b861387d012c353805e40c2898ab537c6784 Mon Sep 17 00:00:00 2001 From: Meni Yakove Date: Tue, 14 Apr 2026 10:03:55 +0000 Subject: [PATCH] feat: add retry with exponential backoff for GitHub API calls --- CLAUDE.md | 55 ++-- webhook_server/libs/github_api.py | 161 ++++++++-- .../libs/handlers/check_run_handler.py | 50 ++- .../libs/handlers/issue_comment_handler.py | 119 +++++-- .../libs/handlers/labels_handler.py | 29 +- .../libs/handlers/owners_files_handler.py | 51 +-- .../libs/handlers/pull_request_handler.py | 194 ++++++++--- webhook_server/libs/handlers/push_handler.py | 6 +- .../libs/handlers/runner_handler.py | 128 ++++++-- webhook_server/libs/test_oracle.py | 10 +- .../tests/test_clean_rebase_detection.py | 2 +- webhook_server/tests/test_github_retry.py | 300 ++++++++++++++++++ webhook_server/tests/test_runner_handler.py | 11 +- webhook_server/tests/test_test_oracle.py | 4 +- webhook_server/utils/github_retry.py | 128 ++++++++ 15 files changed, 1038 insertions(+), 210 deletions(-) create mode 100644 webhook_server/tests/test_github_retry.py create mode 100644 webhook_server/utils/github_retry.py diff --git a/CLAUDE.md b/CLAUDE.md index 460417e2..974d6765 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -166,7 +166,7 @@ def user(self): 1. **Repository Data** - `repository_data` ALWAYS set before handlers instantiate 2. **Webhook User Objects** - `user.node_id`, `user.type`, `sender` always exist -3. **PyGithub REST API** - **🔴 CRITICAL:** PyGithub is blocking - **MUST** wrap with `asyncio.to_thread()` +3. **PyGithub REST API** - **🔴 CRITICAL:** PyGithub is blocking - **MUST** wrap with `github_api_call()` (provides retry with exponential backoff) --- @@ -220,7 +220,7 @@ FastAPI-based GitHub webhook server that automates repository management and pul - `webhook_server/libs/github_api.py` provides core `GithubWebhook` class - Uses PyGithub (REST API v3) for all GitHub operations -- **🔴 CRITICAL:** PyGithub is synchronous/blocking - **MUST** wrap with `asyncio.to_thread()` +- **🔴 CRITICAL:** PyGithub is synchronous/blocking - **MUST** wrap with `github_api_call()` from `webhook_server.utils.github_retry` (retries transient 500/502/503/504 errors with exponential backoff) - Supports multiple GitHub tokens with automatic failover **Log Viewer System:** @@ -290,9 +290,9 @@ class SomeHandler: # Log results ``` -### 🔴 MANDATORY: Non-Blocking PyGithub Operations +### 🔴 MANDATORY: Non-Blocking PyGithub Operations with Retry -**CRITICAL:** PyGithub is synchronous - ALL operations MUST use `asyncio.to_thread()` +**CRITICAL:** PyGithub is synchronous - ALL operations MUST use `github_api_call()` from `webhook_server.utils.github_retry`. This wraps `asyncio.to_thread()` with retry logic for transient GitHub API failures (HTTP 500/502/503/504). #### What Blocks the Event Loop @@ -316,34 +316,38 @@ class SomeHandler: ```python import asyncio -from github.PullRequest import PullRequest -# ✅ CORRECT - Wrap ALL method calls -await asyncio.to_thread(pull_request.create_issue_comment, "Comment") -await asyncio.to_thread(pull_request.add_to_labels, "label") -await asyncio.to_thread(repository.get_pull, number) +from webhook_server.utils.github_retry import github_api_call + +# ✅ CORRECT - Wrap ALL method calls with github_api_call (includes retry) +await github_api_call(pull_request.create_issue_comment, "Comment", logger=self.logger, log_prefix=self.log_prefix) +await github_api_call(pull_request.add_to_labels, "label", logger=self.logger, log_prefix=self.log_prefix) +await github_api_call(repository.get_pull, number, logger=self.logger, log_prefix=self.log_prefix) # ✅ CORRECT - Wrap ALL property accesses that may trigger API calls -is_draft = await asyncio.to_thread(lambda: pull_request.draft) -mergeable = await asyncio.to_thread(lambda: pull_request.mergeable) -labels = await asyncio.to_thread(lambda: list(pull_request.labels)) +is_draft = await github_api_call(lambda: pull_request.draft, logger=self.logger, log_prefix=self.log_prefix) +mergeable = await github_api_call(lambda: pull_request.mergeable, logger=self.logger, log_prefix=self.log_prefix) +labels = await github_api_call(lambda: list(pull_request.labels), logger=self.logger, log_prefix=self.log_prefix) # ✅ CORRECT - Wrap PaginatedList iteration -commits = await asyncio.to_thread(lambda: list(pull_request.get_commits())) +commits = await github_api_call(lambda: list(pull_request.get_commits()), logger=self.logger, log_prefix=self.log_prefix) for commit in commits: await process_commit(commit) # ✅ CORRECT - Concurrent operations is_draft, mergeable, state = await asyncio.gather( - asyncio.to_thread(lambda: pull_request.draft), - asyncio.to_thread(lambda: pull_request.mergeable), - asyncio.to_thread(lambda: pull_request.state), + github_api_call(lambda: pull_request.draft, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.mergeable, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.state, logger=self.logger, log_prefix=self.log_prefix), ) # ❌ WRONG - NEVER call PyGithub directly pull_request.create_issue_comment("Comment") # BLOCKS! is_draft = pull_request.draft # BLOCKS! for commit in pull_request.get_commits(): ... # BLOCKS! + +# ❌ WRONG - NEVER use raw asyncio.to_thread (no retry protection) +await asyncio.to_thread(pull_request.create_issue_comment, "Comment") # NO RETRY! ``` #### Decision Tree @@ -351,17 +355,17 @@ for commit in pull_request.get_commits(): ... # BLOCKS! Before accessing ANY PyGithub object: 1. Is this a PyGithub object? → YES, it may block -2. Calling a method? → **DEFINITELY BLOCKS** - wrap in `asyncio.to_thread()` -3. Accessing a property? → **MAY BLOCK** - wrap in `asyncio.to_thread(lambda: obj.property)` -4. Iterating PaginatedList? → **BLOCKS** - wrap in `asyncio.to_thread(lambda: list(...))` +2. Calling a method? → **DEFINITELY BLOCKS** - wrap in `github_api_call()` +3. Accessing a property? → **MAY BLOCK** - wrap in `github_api_call(lambda: obj.property, logger=self.logger, log_prefix=self.log_prefix)` +4. Iterating PaginatedList? → **BLOCKS** - wrap in `github_api_call(lambda: list(...), logger=self.logger, log_prefix=self.log_prefix)` 5. Webhook payload attribute? → Usually safe (`.number`, `.title`) -6. **Unsure? ALWAYS wrap in `asyncio.to_thread()`** +6. **Unsure? ALWAYS wrap in `github_api_call()`** **Why this is critical:** - PyGithub is synchronous - each operation blocks 100ms-2 seconds - Blocking = frozen server (no other webhooks processed) -- `asyncio.to_thread()` runs code in thread pool, keeps event loop responsive +- `github_api_call()` runs code in thread pool via `asyncio.to_thread()`, keeps event loop responsive, and retries on transient GitHub API failures - **NOT OPTIONAL** - required for correct async operation **Impact of blocking:** @@ -376,13 +380,13 @@ Before accessing ANY PyGithub object: ```python async def add_pr_comment(self, pull_request: PullRequest, body: str) -> None: - await asyncio.to_thread(pull_request.create_issue_comment, body) + await github_api_call(pull_request.create_issue_comment, body, logger=self.logger, log_prefix=self.log_prefix) async def check_pr_status(self, pull_request: PullRequest) -> tuple[bool, bool, str]: return await asyncio.gather( - asyncio.to_thread(lambda: pull_request.draft), - asyncio.to_thread(lambda: pull_request.mergeable), - asyncio.to_thread(lambda: pull_request.state), + github_api_call(lambda: pull_request.draft, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.mergeable, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.state, logger=self.logger, log_prefix=self.log_prefix), ) ``` @@ -605,6 +609,7 @@ mock_api = AsyncMock() mock_api.get_pull_request.return_value = mock_pr_data with patch("asyncio.to_thread", side_effect=mock_to_thread): + # Note: Tests patch asyncio.to_thread since github_api_call delegates to it internally result = await unified_api.get_pr_for_check_runs(owner, repo, number) ``` diff --git a/webhook_server/libs/github_api.py b/webhook_server/libs/github_api.py index 2aa76044..1d3c41c6 100644 --- a/webhook_server/libs/github_api.py +++ b/webhook_server/libs/github_api.py @@ -46,6 +46,7 @@ DEFAULT_BRANCH_PROTECTION, get_repository_github_app_api, ) +from webhook_server.utils.github_retry import github_api_call from webhook_server.utils.helpers import ( _redact_secrets, get_api_with_highest_rate_limit, @@ -233,7 +234,9 @@ async def _get_token_metrics(self) -> str: f"remaining: {remaining})" ) - final_rate_limit = await asyncio.to_thread(self.github_api.get_rate_limit) + final_rate_limit = await github_api_call( + self.github_api.get_rate_limit, logger=self.logger, log_prefix=self.log_prefix + ) final_remaining = final_rate_limit.rate.remaining # Fallback to global rate limit calculation (inaccurate under concurrency) @@ -295,7 +298,9 @@ async def _clone_repository( try: github_token = self.token - clone_url = await asyncio.to_thread(lambda: self.repository.clone_url) + clone_url = await github_api_call( + lambda: self.repository.clone_url, logger=self.logger, log_prefix=self.log_prefix + ) clone_url_with_token = clone_url.replace("https://", f"https://{github_token}@") rc, _, err = await run_command( @@ -315,7 +320,9 @@ def redact_output(value: str) -> str: # Configure git user git_cmd = f"git -C {self.clone_repo_dir}" - owner_login = await asyncio.to_thread(lambda: self.repository.owner.login) + owner_login = await github_api_call( + lambda: self.repository.owner.login, logger=self.logger, log_prefix=self.log_prefix + ) rc, _, _ = await run_command( command=f"{git_cmd} config user.name '{owner_login}'", log_prefix=self.log_prefix, @@ -335,7 +342,9 @@ def redact_output(value: str) -> str: # Fetch only what's needed instead of all refs if pull_request: # Fetch the base branch first (needed for checkout) - base_ref = await asyncio.to_thread(lambda: pull_request.base.ref) + base_ref = await github_api_call( + lambda: pull_request.base.ref, logger=self.logger, log_prefix=self.log_prefix + ) rc, _, err = await run_command( command=f"{git_cmd} fetch origin {base_ref}", log_prefix=self.log_prefix, @@ -347,7 +356,9 @@ def redact_output(value: str) -> str: raise RuntimeError(f"Failed to fetch base branch {base_ref}: {redacted_err}") # Fetch only this specific PR's ref - pr_number = await asyncio.to_thread(lambda: pull_request.number) + pr_number = await github_api_call( + lambda: pull_request.number, logger=self.logger, log_prefix=self.log_prefix + ) rc, _, err = await run_command( command=f"{git_cmd} fetch origin +refs/pull/{pr_number}/head:refs/remotes/origin/pr/{pr_number}", log_prefix=self.log_prefix, @@ -372,7 +383,9 @@ def redact_output(value: str) -> str: # Determine checkout target if pull_request: - checkout_target = await asyncio.to_thread(lambda: pull_request.base.ref) + checkout_target = await github_api_call( + lambda: pull_request.base.ref, logger=self.logger, log_prefix=self.log_prefix + ) else: # For push events (tags only - branch pushes skip cloning) # checkout_ref guaranteed to be non-None by validation at function start @@ -518,9 +531,15 @@ async def process(self) -> Any: if pull_request: # Update context with PR info if self.ctx: - pr_number = await asyncio.to_thread(lambda: pull_request.number) - pr_title = await asyncio.to_thread(lambda: pull_request.title) - pr_author = await asyncio.to_thread(lambda: pull_request.user.login) + pr_number = await github_api_call( + lambda: pull_request.number, logger=self.logger, log_prefix=self.log_prefix + ) + pr_title = await github_api_call( + lambda: pull_request.title, logger=self.logger, log_prefix=self.log_prefix + ) + pr_author = await github_api_call( + lambda: pull_request.user.login, logger=self.logger, log_prefix=self.log_prefix + ) self.ctx.pr_number = pr_number self.ctx.pr_title = pr_title self.ctx.pr_author = pr_author @@ -528,7 +547,7 @@ async def process(self) -> Any: self.log_prefix = self.prepare_log_prefix(pull_request=pull_request) self.logger.debug(f"{self.log_prefix} {event_log}") - if await asyncio.to_thread(lambda: pull_request.draft): + if await github_api_call(lambda: pull_request.draft, logger=self.logger, log_prefix=self.log_prefix): allow_commands_on_draft = self.config.get_value("allow-commands-on-draft-prs") # Validate type: must be a list, treat invalid types as None (default-deny) @@ -708,7 +727,9 @@ async def check_token(api: github.Github, token: str) -> str | None: """Check a single API token and return the user login if valid, None otherwise.""" token_suffix = f"...{token[-4:]}" if token else "unknown" try: - rate_limit_remaining = await asyncio.to_thread(lambda: api.rate_limiting[-1]) + rate_limit_remaining = await github_api_call( + lambda: api.rate_limiting[-1], logger=self.logger, log_prefix=self.log_prefix + ) except Exception as ex: self.logger.warning( f"{self.log_prefix} Failed to get API rate limit for token ending in '{token_suffix}', " @@ -724,7 +745,9 @@ async def check_token(api: github.Github, token: str) -> str | None: return None try: - _api_user = await asyncio.to_thread(lambda: api.get_user().login) + _api_user = await github_api_call( + lambda: api.get_user().login, logger=self.logger, log_prefix=self.log_prefix + ) except Exception as ex: self.logger.exception( f"{self.log_prefix} Failed to get API user for token ending in '{token_suffix}', skipping. {ex}" @@ -924,7 +947,9 @@ def _sanitize_item(item: Any) -> str: async def get_pull_request(self, number: int | None = None) -> PullRequest | None: if number: self.logger.debug(f"{self.log_prefix} Attempting to get PR by number: {number}") - return await asyncio.to_thread(self.repository.get_pull, number) + return await github_api_call( + self.repository.get_pull, number, logger=self.logger, log_prefix=self.log_prefix + ) # Try to get PR number from hook_data self.logger.debug(f"{self.log_prefix} Attempting to get PR from webhook payload") @@ -934,7 +959,9 @@ async def get_pull_request(self, number: int | None = None) -> PullRequest | Non if pr_number: self.logger.debug(f"{self.log_prefix} Found PR number in payload: {pr_number}") try: - return await asyncio.to_thread(self.repository.get_pull, pr_number) + return await github_api_call( + self.repository.get_pull, pr_number, logger=self.logger, log_prefix=self.log_prefix + ) except GithubException as ex: self.logger.debug(f"{self.log_prefix} Failed to get PR {pr_number} from payload: {ex}") else: @@ -948,8 +975,15 @@ def _get_pr_head_sha(pr: PullRequest) -> str: if self.github_event == "check_run": head_sha = self.hook_data["check_run"]["head_sha"] self.logger.debug(f"{self.log_prefix} Searching open PRs for check_run head SHA: {head_sha}") - open_pulls = await asyncio.to_thread(lambda: list(self.repository.get_pulls(state="open"))) - head_shas = await asyncio.gather(*(asyncio.to_thread(_get_pr_head_sha, pr) for pr in open_pulls)) + open_pulls = await github_api_call( + lambda: list(self.repository.get_pulls(state="open")), logger=self.logger, log_prefix=self.log_prefix + ) + head_shas = await asyncio.gather( + *( + github_api_call(_get_pr_head_sha, pr, logger=self.logger, log_prefix=self.log_prefix) + for pr in open_pulls + ) + ) for _pull_request, pr_head_sha in zip(open_pulls, head_shas, strict=False): if pr_head_sha == head_sha: self.logger.debug( @@ -963,8 +997,15 @@ def _get_pr_head_sha(pr: PullRequest) -> str: if self.github_event == "status": sha = self.hook_data["sha"] self.logger.debug(f"{self.log_prefix} Searching open PRs for status SHA: {sha}") - open_pulls = await asyncio.to_thread(lambda: list(self.repository.get_pulls(state="open"))) - head_shas = await asyncio.gather(*(asyncio.to_thread(_get_pr_head_sha, pr) for pr in open_pulls)) + open_pulls = await github_api_call( + lambda: list(self.repository.get_pulls(state="open")), logger=self.logger, log_prefix=self.log_prefix + ) + head_shas = await asyncio.gather( + *( + github_api_call(_get_pr_head_sha, pr, logger=self.logger, log_prefix=self.log_prefix) + for pr in open_pulls + ) + ) for _pull_request, pr_head_sha in zip(open_pulls, head_shas, strict=False): if pr_head_sha == sha: self.logger.debug( @@ -978,12 +1019,18 @@ def _get_pr_head_sha(pr: PullRequest) -> str: commit: dict[str, Any] = self.hook_data.get("commit", {}) if commit: self.logger.debug(f"{self.log_prefix} Attempting to get PR from commit SHA: {commit.get('sha', 'unknown')}") - commit_obj = await asyncio.to_thread(self.repository.get_commit, commit["sha"]) + commit_obj = await github_api_call( + self.repository.get_commit, commit["sha"], logger=self.logger, log_prefix=self.log_prefix + ) with contextlib.suppress(Exception): - _pulls = await asyncio.to_thread(commit_obj.get_pulls) - if _pulls: - self.logger.debug(f"{self.log_prefix} Found PR from commit SHA: {_pulls[0].number}") - return _pulls[0] + pulls = await github_api_call( + lambda: list(commit_obj.get_pulls()), + logger=self.logger, + log_prefix=self.log_prefix, + ) + if pulls: + self.logger.debug(f"{self.log_prefix} Found PR from commit SHA: {pulls[0].number}") + return pulls[0] self.logger.debug(f"{self.log_prefix} No PR found for commit SHA") else: self.logger.debug(f"{self.log_prefix} No commit data in webhook payload") @@ -1045,16 +1092,56 @@ async def get_unresolved_review_threads(self, pr_number: int) -> list[dict[str, "prNumber": pr_number, "cursor": cursor, } - response = await client.post( - "https://api.github.com/graphql", - json={"query": query, "variables": variables}, - headers={ - "Authorization": f"Bearer {self.token}", - "Content-Type": "application/json", - }, - timeout=30.0, - ) - response.raise_for_status() + last_exception: Exception | None = None + for attempt in range(5): # max 4 retries + try: + response = await client.post( + "https://api.github.com/graphql", + json={"query": query, "variables": variables}, + headers={ + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json", + }, + timeout=30.0, + ) + response.raise_for_status() + last_exception = None + break + except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout) as ex: + last_exception = ex + if attempt == 4: + break + delay = 2 * (2**attempt) + self.logger.warning( + "%s GraphQL API call failed (attempt %d/%d), retrying in %ds: %s: %s", + self.log_prefix, + attempt + 1, + 5, + delay, + type(ex).__name__, + ex, + ) + await asyncio.sleep(delay) + except httpx.HTTPStatusError as ex: + if ex.response.status_code in (500, 502, 503, 504): + last_exception = ex + if attempt == 4: + break + delay = 2 * (2**attempt) + self.logger.warning( + "%s GraphQL API call failed (attempt %d/%d), retrying in %ds: HTTP %d", + self.log_prefix, + attempt + 1, + 5, + delay, + ex.response.status_code, + ) + await asyncio.sleep(delay) + else: + raise + + if last_exception is not None: + raise last_exception data = response.json() if "errors" in data: @@ -1096,8 +1183,12 @@ async def get_unresolved_review_threads(self, pr_number: int) -> list[dict[str, return unresolved_threads async def _get_last_commit(self, pull_request: PullRequest) -> Commit: - _commits = await asyncio.to_thread(pull_request.get_commits) - return list(_commits)[-1] + commits = await github_api_call( + lambda: list(pull_request.get_commits()), + logger=self.logger, + log_prefix=self.log_prefix, + ) + return commits[-1] @staticmethod def _comment_with_details(title: str, body: str) -> str: diff --git a/webhook_server/libs/handlers/check_run_handler.py b/webhook_server/libs/handlers/check_run_handler.py index 773af73d..54d2c047 100644 --- a/webhook_server/libs/handlers/check_run_handler.py +++ b/webhook_server/libs/handlers/check_run_handler.py @@ -21,6 +21,7 @@ TOX_STR, VERIFIED_LABEL_STR, ) +from webhook_server.utils.github_retry import github_api_call from webhook_server.utils.helpers import strip_ansi_codes if TYPE_CHECKING: @@ -83,7 +84,9 @@ async def process_pull_request_check_run_webhook_data(self, pull_request: PullRe label=AUTOMERGE_LABEL_STR, pull_request=pull_request ): try: - await asyncio.to_thread(pull_request.merge, merge_method="SQUASH") + await github_api_call( + pull_request.merge, merge_method="SQUASH", logger=self.logger, log_prefix=self.log_prefix + ) self.logger.info( f"{self.log_prefix} Successfully auto-merged pull request #{pull_request.number}" ) @@ -176,7 +179,12 @@ async def set_check_run_status( self.logger.debug( f"{self.log_prefix} Setting check run for {check_run}, status={status}, conclusion={conclusion}" ) - await asyncio.to_thread(self.github_webhook.repository_by_github_app.create_check_run, **kwargs) + await github_api_call( + self.github_webhook.repository_by_github_app.create_check_run, + **kwargs, + logger=self.logger, + log_prefix=self.log_prefix, + ) if conclusion in (SUCCESS_STR, IN_PROGRESS_STR): self.logger.info(msg) return @@ -187,7 +195,12 @@ async def set_check_run_status( self.logger.exception(f"{self.log_prefix} Failed to set check run status for {check_run}") kwargs["conclusion"] = FAILURE_STR kwargs["status"] = "completed" - await asyncio.to_thread(self.github_webhook.repository_by_github_app.create_check_run, **kwargs) + await github_api_call( + self.github_webhook.repository_by_github_app.create_check_run, + **kwargs, + logger=self.logger, + log_prefix=self.log_prefix, + ) def get_check_run_text(self, err: str, out: str) -> str: # Strip ANSI escape codes from output to prevent scrambled characters in GitHub UI @@ -249,7 +262,12 @@ def get_check_run_text(self, err: str, out: str) -> str: async def is_check_run_in_progress(self, check_run: str) -> bool: if self.github_webhook.last_commit: - for run in await asyncio.to_thread(self.github_webhook.last_commit.get_check_runs): + runs = await github_api_call( + lambda: list(self.github_webhook.last_commit.get_check_runs()), + logger=self.logger, + log_prefix=self.log_prefix, + ) + for run in runs: if run.name == check_run and run.status == IN_PROGRESS_STR: self.logger.debug(f"{self.log_prefix} Check run {check_run} is in progress.") return True @@ -392,7 +410,9 @@ async def all_required_status_checks(self, pull_request: PullRequest) -> list[st async def get_branch_required_status_checks(self, pull_request: PullRequest) -> list[str]: # Check if private repo first (cache to avoid repeated API calls) if self._repository_private is None: - self._repository_private = await asyncio.to_thread(lambda: self.repository.private) + self._repository_private = await github_api_call( + lambda: self.repository.private, logger=self.logger, log_prefix=self.log_prefix + ) if self._repository_private: self.logger.info( @@ -404,10 +424,22 @@ async def get_branch_required_status_checks(self, pull_request: PullRequest) -> if self._branch_required_status_checks is not None: return self._branch_required_status_checks - pull_request_branch = await asyncio.to_thread(self.repository.get_branch, pull_request.base.ref) - branch_protection = await asyncio.to_thread(pull_request_branch.get_protection) - branch_required_status_checks = await asyncio.to_thread( - lambda: branch_protection.required_status_checks.contexts + base_ref = await github_api_call( + lambda: pull_request.base.ref, + logger=self.logger, + log_prefix=self.log_prefix, + ) + pull_request_branch = await github_api_call( + self.repository.get_branch, + base_ref, + logger=self.logger, + log_prefix=self.log_prefix, + ) + branch_protection = await github_api_call( + pull_request_branch.get_protection, logger=self.logger, log_prefix=self.log_prefix + ) + branch_required_status_checks = await github_api_call( + lambda: branch_protection.required_status_checks.contexts, logger=self.logger, log_prefix=self.log_prefix ) self.logger.debug(f"{self.log_prefix} branch_required_status_checks: {branch_required_status_checks}") self._branch_required_status_checks = branch_required_status_checks diff --git a/webhook_server/libs/handlers/issue_comment_handler.py b/webhook_server/libs/handlers/issue_comment_handler.py index 1b67d02f..3ac7a859 100644 --- a/webhook_server/libs/handlers/issue_comment_handler.py +++ b/webhook_server/libs/handlers/issue_comment_handler.py @@ -35,6 +35,7 @@ VERIFIED_LABEL_STR, WIP_STR, ) +from webhook_server.utils.github_retry import github_api_call if TYPE_CHECKING: from webhook_server.libs.github_api import GithubWebhook @@ -96,7 +97,9 @@ async def process_comment_webhook_data(self, pull_request: PullRequest) -> None: # Execute all commands in parallel if _user_commands: # Cache draft status once to avoid repeated API calls - is_draft = await asyncio.to_thread(lambda: pull_request.draft) + is_draft = await github_api_call( + lambda: pull_request.draft, logger=self.logger, log_prefix=self.log_prefix + ) tasks: list[Coroutine[Any, Any, Any] | Task[Any]] = [] for user_command in _user_commands: @@ -118,10 +121,9 @@ async def process_comment_webhook_data(self, pull_request: PullRequest) -> None: failed_commands: list[tuple[str, Exception]] = [] for idx, result in enumerate(results): user_command = _user_commands[idx] + if isinstance(result, asyncio.CancelledError): + raise result if isinstance(result, Exception): - # Re-raise CancelledError immediately to allow cancellation to propagate - if isinstance(result, asyncio.CancelledError): - raise result self.logger.error(f"{self.log_prefix} Command execution failed: /{user_command} - {result}") failed_commands.append((user_command, result)) @@ -189,10 +191,12 @@ async def user_commands( f"{self.log_prefix} Command {_command} is not allowed on draft PRs. " f"Allowed commands: {allow_commands_on_draft}" ) - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"Command `/{_command}` is not allowed on draft PRs.\n" f"Allowed commands on draft PRs: {', '.join(allow_commands_on_draft)}", + logger=self.logger, + log_prefix=self.log_prefix, ) return @@ -221,7 +225,12 @@ async def user_commands( missing_command_arg_comment_msg: str = f"{_command} requires an argument" error_msg: str = f"{self.log_prefix} {missing_command_arg_comment_msg}" self.logger.debug(error_msg) - await asyncio.to_thread(pull_request.create_issue_comment, body=missing_command_arg_comment_msg) + await github_api_call( + pull_request.create_issue_comment, + body=missing_command_arg_comment_msg, + logger=self.logger, + log_prefix=self.log_prefix, + ) return if _command == AUTOMERGE_LABEL_STR: @@ -231,7 +240,9 @@ async def user_commands( ): msg = "Only maintainers or approvers can set pull request to auto-merge" self.logger.debug(f"{self.log_prefix} {msg}") - await asyncio.to_thread(pull_request.create_issue_comment, body=msg) + await github_api_call( + pull_request.create_issue_comment, body=msg, logger=self.logger, log_prefix=self.log_prefix + ) return await self.labels_handler._add_label(pull_request=pull_request, label=AUTOMERGE_LABEL_STR) @@ -245,7 +256,12 @@ async def user_commands( await self._add_reviewer_by_user_comment(pull_request=pull_request, reviewer=_args) elif _command == COMMAND_ADD_ALLOWED_USER_STR: - await asyncio.to_thread(pull_request.create_issue_comment, body=f"{_args} is now allowed to run commands") + await github_api_call( + pull_request.create_issue_comment, + body=f"{_args} is now allowed to run commands", + logger=self.logger, + log_prefix=self.log_prefix, + ) elif _command == COMMAND_ASSIGN_REVIEWERS_STR: await self.owners_file_handler.assign_reviewers(pull_request=pull_request) @@ -301,38 +317,55 @@ async def user_commands( msg = f"No {BUILD_AND_PUSH_CONTAINER_STR} configured for this repository" error_msg = f"{self.log_prefix} {msg}" self.logger.debug(error_msg) - await asyncio.to_thread(pull_request.create_issue_comment, msg) + await github_api_call( + pull_request.create_issue_comment, msg, logger=self.logger, log_prefix=self.log_prefix + ) elif _command == WIP_STR: wip_for_title: str = f"{WIP_STR.upper()}:" if remove: label_changed = await self.labels_handler._remove_label(pull_request=pull_request, label=WIP_STR) if label_changed: - pr_title = await asyncio.to_thread(lambda: pull_request.title) + pr_title = await github_api_call( + lambda: pull_request.title, logger=self.logger, log_prefix=self.log_prefix + ) # Case-insensitive check and removal of WIP prefix pr_title_upper = pr_title.upper() if pr_title_upper.startswith("WIP: "): new_title = pr_title[5:] # Remove "WIP: " (5 chars) - await asyncio.to_thread(pull_request.edit, title=new_title) + await github_api_call( + pull_request.edit, title=new_title, logger=self.logger, log_prefix=self.log_prefix + ) elif pr_title_upper.startswith("WIP:"): new_title = pr_title[4:] # Remove "WIP:" (4 chars) - await asyncio.to_thread(pull_request.edit, title=new_title) + await github_api_call( + pull_request.edit, title=new_title, logger=self.logger, log_prefix=self.log_prefix + ) else: label_changed = await self.labels_handler._add_label(pull_request=pull_request, label=WIP_STR) if label_changed: - pr_title = await asyncio.to_thread(lambda: pull_request.title) + pr_title = await github_api_call( + lambda: pull_request.title, logger=self.logger, log_prefix=self.log_prefix + ) # Case-insensitive check: only prepend if prefix is not already there (idempotent) if not pr_title.upper().startswith("WIP:"): - await asyncio.to_thread(pull_request.edit, title=f"{wip_for_title} {pr_title}") + await github_api_call( + pull_request.edit, + title=f"{wip_for_title} {pr_title}", + logger=self.logger, + log_prefix=self.log_prefix, + ) elif _command == HOLD_LABEL_STR: if reviewed_user not in self.owners_file_handler.all_pull_request_approvers: self.logger.debug( f"{self.log_prefix} {reviewed_user} is not an approver, not adding {HOLD_LABEL_STR} label" ) - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"{reviewed_user} is not part of the approver, only approvers can mark pull request with hold", + logger=self.logger, + log_prefix=self.log_prefix, ) else: if remove: @@ -367,23 +400,31 @@ async def user_commands( task.add_done_callback(_background_tasks.discard) async def create_comment_reaction(self, pull_request: PullRequest, issue_comment_id: int, reaction: str) -> None: - _comment = await asyncio.to_thread(pull_request.get_issue_comment, issue_comment_id) - await asyncio.to_thread(_comment.create_reaction, reaction) + _comment = await github_api_call( + pull_request.get_issue_comment, issue_comment_id, logger=self.logger, log_prefix=self.log_prefix + ) + await github_api_call(_comment.create_reaction, reaction, logger=self.logger, log_prefix=self.log_prefix) async def _add_reviewer_by_user_comment(self, pull_request: PullRequest, reviewer: str) -> None: reviewer = reviewer.strip("@") self.logger.info(f"{self.log_prefix} Adding reviewer {reviewer} by user comment") - repo_contributors = list(await asyncio.to_thread(self.repository.get_contributors)) + repo_contributors = await github_api_call( + lambda: list(self.repository.get_contributors()), + logger=self.logger, + log_prefix=self.log_prefix, + ) self.logger.debug(f"{self.log_prefix} Repo contributors are: {repo_contributors}") for contributer in repo_contributors: if contributer.login == reviewer: - await asyncio.to_thread(pull_request.create_review_request, [reviewer]) + await github_api_call( + pull_request.create_review_request, [reviewer], logger=self.logger, log_prefix=self.log_prefix + ) return _err = f"not adding reviewer {reviewer} by user comment, {reviewer} is not part of contributers" self.logger.debug(f"{self.log_prefix} {_err}") - await asyncio.to_thread(pull_request.create_issue_comment, _err) + await github_api_call(pull_request.create_issue_comment, _err, logger=self.logger, log_prefix=self.log_prefix) async def process_cherry_pick_command( self, pull_request: PullRequest, command_args: str, reviewed_user: str @@ -430,7 +471,9 @@ async def process_cherry_pick_command( for _target_branch in _target_branches: try: - await asyncio.to_thread(self.repository.get_branch, _target_branch) + await github_api_call( + self.repository.get_branch, _target_branch, logger=self.logger, log_prefix=self.log_prefix + ) _exits_target_branches.add(_target_branch) except Exception: _non_exits_target_branches_msg += f"Target branch `{_target_branch}` does not exist\n" @@ -441,13 +484,23 @@ async def process_cherry_pick_command( if _non_exits_target_branches_msg: self.logger.info(f"{self.log_prefix} {_non_exits_target_branches_msg}") - await asyncio.to_thread(pull_request.create_issue_comment, _non_exits_target_branches_msg) + await github_api_call( + pull_request.create_issue_comment, + _non_exits_target_branches_msg, + logger=self.logger, + log_prefix=self.log_prefix, + ) if not _exits_target_branches: return # Filter out branches that already have cherry-pick labels - existing_labels = {label.name for label in await asyncio.to_thread(lambda: list(pull_request.labels))} + existing_labels = { + label.name + for label in await github_api_call( + lambda: list(pull_request.labels), logger=self.logger, log_prefix=self.log_prefix + ) + } _already_cherry_picked: list[str] = [] _branches_to_process: set[str] = set() @@ -460,10 +513,12 @@ async def process_cherry_pick_command( if _already_cherry_picked: already_msg = ", ".join(f"`{b}`" for b in _already_cherry_picked) - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"Cherry-pick label already present for: {already_msg}\n" "To re-trigger, remove the cherry-pick label(s) and run the command again.", + logger=self.logger, + log_prefix=self.log_prefix, ) if not _branches_to_process: @@ -478,7 +533,9 @@ async def process_cherry_pick_command( """ self.logger.info(f"{self.log_prefix} {info_msg}") - await asyncio.to_thread(pull_request.create_issue_comment, info_msg) + await github_api_call( + pull_request.create_issue_comment, info_msg, logger=self.logger, log_prefix=self.log_prefix + ) else: for _branch in _branches_to_process: label_added = await self.labels_handler._add_label( @@ -518,7 +575,9 @@ async def process_retest_command( msg = "No test defined to retest" error_msg = f"{self.log_prefix} {msg}." self.logger.debug(error_msg) - await asyncio.to_thread(pull_request.create_issue_comment, msg) + await github_api_call( + pull_request.create_issue_comment, msg, logger=self.logger, log_prefix=self.log_prefix + ) return if "all" in command_args: @@ -526,7 +585,9 @@ async def process_retest_command( msg = "Invalid command. `all` cannot be used with other tests" error_msg = f"{self.log_prefix} {msg}." self.logger.debug(error_msg) - await asyncio.to_thread(pull_request.create_issue_comment, msg) + await github_api_call( + pull_request.create_issue_comment, msg, logger=self.logger, log_prefix=self.log_prefix + ) return else: @@ -547,7 +608,9 @@ async def process_retest_command( msg = f"No {' '.join(_not_supported_retests)} configured for this repository" error_msg = f"{self.log_prefix} {msg}." self.logger.debug(error_msg) - await asyncio.to_thread(pull_request.create_issue_comment, msg) + await github_api_call( + pull_request.create_issue_comment, msg, logger=self.logger, log_prefix=self.log_prefix + ) if _supported_retests: # Use runner_handler.run_retests() to avoid duplication diff --git a/webhook_server/libs/handlers/labels_handler.py b/webhook_server/libs/handlers/labels_handler.py index 49764cff..4bcef05d 100644 --- a/webhook_server/libs/handlers/labels_handler.py +++ b/webhook_server/libs/handlers/labels_handler.py @@ -27,6 +27,7 @@ STATIC_LABELS_DICT, WIP_STR, ) +from webhook_server.utils.github_retry import github_api_call if TYPE_CHECKING: from webhook_server.libs.github_api import GithubWebhook @@ -111,7 +112,11 @@ async def label_exists_in_pull_request(self, pull_request: PullRequest, label: s return label in await self.pull_request_labels_names(pull_request=pull_request) async def pull_request_labels_names(self, pull_request: PullRequest) -> list[str]: - labels = await asyncio.to_thread(pull_request.get_labels) + labels = await github_api_call( + lambda: list(pull_request.get_labels()), + logger=self.logger, + log_prefix=self.log_prefix, + ) return [lb.name for lb in labels] async def _remove_label(self, pull_request: PullRequest, label: str) -> bool: @@ -119,7 +124,9 @@ async def _remove_label(self, pull_request: PullRequest, label: str) -> bool: try: if await self.label_exists_in_pull_request(pull_request=pull_request, label=label): self.logger.info(f"{self.log_prefix} Removing label {label}") - await asyncio.to_thread(pull_request.remove_from_labels, label) + await github_api_call( + pull_request.remove_from_labels, label, logger=self.logger, log_prefix=self.log_prefix + ) success = await self.wait_for_label(pull_request=pull_request, label=label, exists=False) return success except Exception as exp: @@ -155,16 +162,22 @@ async def _add_label(self, pull_request: PullRequest, label: str) -> bool: _with_color_msg = f"repository label {label} with color {color}" try: - _repo_label = await asyncio.to_thread(self.repository.get_label, label) - await asyncio.to_thread(_repo_label.edit, name=_repo_label.name, color=color) + _repo_label = await github_api_call( + self.repository.get_label, label, logger=self.logger, log_prefix=self.log_prefix + ) + await github_api_call( + _repo_label.edit, name=_repo_label.name, color=color, logger=self.logger, log_prefix=self.log_prefix + ) self.logger.debug(f"{self.log_prefix} Edit {_with_color_msg}") except UnknownObjectException: self.logger.debug(f"{self.log_prefix} Add {_with_color_msg}") - await asyncio.to_thread(self.repository.create_label, name=label, color=color) + await github_api_call( + self.repository.create_label, name=label, color=color, logger=self.logger, log_prefix=self.log_prefix + ) self.logger.info(f"{self.log_prefix} Adding pull request label {label}") - await asyncio.to_thread(pull_request.add_to_labels, label) + await github_api_call(pull_request.add_to_labels, label, logger=self.logger, log_prefix=self.log_prefix) return await self.wait_for_label(pull_request=pull_request, label=label, exists=True) async def wait_for_label(self, pull_request: PullRequest, label: str, exists: bool) -> bool: @@ -297,8 +310,8 @@ def _get_custom_pr_size_thresholds(self) -> list[tuple[int | float, str, str]]: async def get_size(self, pull_request: PullRequest) -> str: """Calculates size label based on additions and deletions.""" additions, deletions = await asyncio.gather( - asyncio.to_thread(lambda: pull_request.additions), - asyncio.to_thread(lambda: pull_request.deletions), + github_api_call(lambda: pull_request.additions, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.deletions, logger=self.logger, log_prefix=self.log_prefix), ) size = additions + deletions self.logger.debug(f"{self.log_prefix} PR size is {size} (additions: {additions}, deletions: {deletions})") diff --git a/webhook_server/libs/handlers/owners_files_handler.py b/webhook_server/libs/handlers/owners_files_handler.py index 8e4409e9..25b18896 100644 --- a/webhook_server/libs/handlers/owners_files_handler.py +++ b/webhook_server/libs/handlers/owners_files_handler.py @@ -14,6 +14,7 @@ from github.Repository import Repository from webhook_server.utils.constants import COMMAND_ADD_ALLOWED_USER_STR, ROOT_APPROVERS_KEY +from webhook_server.utils.github_retry import github_api_call from webhook_server.utils.helpers import run_command if TYPE_CHECKING: @@ -99,10 +100,10 @@ async def list_changed_files(self, pull_request: PullRequest) -> list[str]: RuntimeError: If git diff command fails asyncio.CancelledError: Propagates cancellation (never caught) """ - # Get base and head SHAs (wrap property accesses in asyncio.to_thread) + # Get base and head SHAs (wrap property accesses in github_api_call for retry support) base_sha, head_sha = await asyncio.gather( - asyncio.to_thread(lambda: pull_request.base.sha), - asyncio.to_thread(lambda: pull_request.head.sha), + github_api_call(lambda: pull_request.base.sha, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.head.sha, logger=self.logger, log_prefix=self.log_prefix), ) # Run git diff command on cloned repository @@ -456,13 +457,18 @@ async def assign_reviewers(self, pull_request: PullRequest) -> None: if reviewer != pull_request.user.login: self.logger.debug(f"{self.log_prefix} Adding reviewer {reviewer}") try: - await asyncio.to_thread(pull_request.create_review_request, [reviewer]) + await github_api_call( + pull_request.create_review_request, [reviewer], logger=self.logger, log_prefix=self.log_prefix + ) assigned_count += 1 except GithubException as ex: self.logger.debug(f"{self.log_prefix} Failed to add reviewer {reviewer}. {ex}") - await asyncio.to_thread( - pull_request.create_issue_comment, f"{reviewer} can not be added as reviewer. {ex}" + await github_api_call( + pull_request.create_issue_comment, + f"{reviewer} can not be added as reviewer. {ex}", + logger=self.logger, + log_prefix=self.log_prefix, ) failed_count += 1 @@ -483,11 +489,12 @@ async def is_user_valid_to_run_commands(self, pull_request: PullRequest, reviewe self.logger.debug(f"{self.log_prefix} Valid users to run commands: {valid_users}") if reviewed_user not in valid_users: - for comment in [ - _comment - for _comment in await asyncio.to_thread(pull_request.get_issue_comments) - if _comment.user.login in allowed_user_to_approve - ]: + issue_comments = await github_api_call( + lambda: list(pull_request.get_issue_comments()), + logger=self.logger, + log_prefix=self.log_prefix, + ) + for comment in [_comment for _comment in issue_comments if _comment.user.login in allowed_user_to_approve]: if allow_user_comment in comment.body: self.logger.debug( f"{self.log_prefix} {reviewed_user} is approved by {comment.user.login} to run commands" @@ -495,7 +502,9 @@ async def is_user_valid_to_run_commands(self, pull_request: PullRequest, reviewe return True self.logger.debug(f"{self.log_prefix} {reviewed_user} is not in {valid_users}") - await asyncio.to_thread(pull_request.create_issue_comment, comment_msg) + await github_api_call( + pull_request.create_issue_comment, comment_msg, logger=self.logger, log_prefix=self.log_prefix + ) return False return True @@ -516,25 +525,31 @@ async def valid_users_to_run_commands(self) -> set[str]: async def get_all_repository_contributors(self) -> list[str]: contributors = await self.repository_contributors - return await asyncio.to_thread(lambda: [val.login for val in contributors]) + return await github_api_call( + lambda: [val.login for val in contributors], logger=self.logger, log_prefix=self.log_prefix + ) async def get_all_repository_collaborators(self) -> list[str]: collaborators = await self.repository_collaborators - return await asyncio.to_thread(lambda: [val.login for val in collaborators]) + return await github_api_call( + lambda: [val.login for val in collaborators], logger=self.logger, log_prefix=self.log_prefix + ) async def get_all_repository_maintainers(self) -> list[str]: maintainers: list[str] = [] # Fix #1: Convert PaginatedList to list in thread pool to avoid blocking during iteration collaborators = await self.repository_collaborators - collaborators_list = await asyncio.to_thread(lambda: list(collaborators)) + collaborators_list = await github_api_call( + lambda: list(collaborators), logger=self.logger, log_prefix=self.log_prefix + ) for user in collaborators_list: # Fix #2: Wrap permissions access in thread pool (property makes blocking API call) def get_user_permissions(u: NamedUser = user) -> Permissions: return u.permissions - permissions = await asyncio.to_thread(get_user_permissions) + permissions = await github_api_call(get_user_permissions, logger=self.logger, log_prefix=self.log_prefix) self.logger.debug(f"{self.log_prefix} User {user.login} permissions: {permissions}") if permissions.admin or permissions.maintain: @@ -545,8 +560,8 @@ def get_user_permissions(u: NamedUser = user) -> Permissions: @functools.cached_property async def repository_collaborators(self) -> PaginatedList[NamedUser]: - return await asyncio.to_thread(self.repository.get_collaborators) + return await github_api_call(self.repository.get_collaborators, logger=self.logger, log_prefix=self.log_prefix) @functools.cached_property async def repository_contributors(self) -> PaginatedList[NamedUser]: - return await asyncio.to_thread(self.repository.get_contributors) + return await github_api_call(self.repository.get_contributors, logger=self.logger, log_prefix=self.log_prefix) diff --git a/webhook_server/libs/handlers/pull_request_handler.py b/webhook_server/libs/handlers/pull_request_handler.py index fd844642..0f700233 100644 --- a/webhook_server/libs/handlers/pull_request_handler.py +++ b/webhook_server/libs/handlers/pull_request_handler.py @@ -42,6 +42,7 @@ VERIFIED_LABEL_STR, WIP_STR, ) +from webhook_server.utils.github_retry import github_api_call from webhook_server.utils.helpers import run_command if TYPE_CHECKING: @@ -220,7 +221,9 @@ async def _post_clean_rebase_comment( f"**Clean rebase detected** \u2014 no code changes compared to previous head (`{before_sha[:7]}`)." ) - await asyncio.to_thread(pull_request.create_issue_comment, body=comment_body) + await github_api_call( + pull_request.create_issue_comment, body=comment_body, logger=self.logger, log_prefix=self.log_prefix + ) except asyncio.CancelledError: raise except Exception: @@ -250,7 +253,14 @@ async def process_pull_request_webhook_data(self, pull_request: PullRequest) -> if hook_action in ("opened", "ready_for_review"): welcome_msg = self._prepare_welcome_comment() - tasks.append(asyncio.to_thread(pull_request.create_issue_comment, body=welcome_msg)) + tasks.append( + github_api_call( + pull_request.create_issue_comment, + body=welcome_msg, + logger=self.logger, + log_prefix=self.log_prefix, + ) + ) tasks.append(self.create_issue_for_new_pull_request(pull_request=pull_request)) tasks.append(self.set_wip_label_based_on_title(pull_request=pull_request)) @@ -325,7 +335,9 @@ async def process_pull_request_webhook_data(self, pull_request: PullRequest) -> if is_merged := pull_request_data.get("merged", False): self.logger.info(f"{self.log_prefix} PR is merged") - labels = await asyncio.to_thread(lambda: list(pull_request.labels)) + labels = await github_api_call( + lambda: list(pull_request.labels), logger=self.logger, log_prefix=self.log_prefix + ) if cherry_pick_labels := [ _label for _label in labels if _label.name.startswith(CHERRY_PICK_LABEL_PREFIX) ]: @@ -365,7 +377,9 @@ async def process_pull_request_webhook_data(self, pull_request: PullRequest) -> return self.logger.info(f"{self.log_prefix} PR {pull_request.number} {hook_action} with {labeled}") - labels = await asyncio.to_thread(lambda: list(pull_request.labels)) + labels = await github_api_call( + lambda: list(pull_request.labels), logger=self.logger, log_prefix=self.log_prefix + ) self.logger.debug(f"{self.log_prefix} PR labels are {labels}") _split_label = labeled.split(LABELS_SEPARATOR, 1) @@ -795,7 +809,9 @@ async def label_all_opened_pull_requests_merge_state_after_merged(self) -> None: self.logger.info(f"{self.log_prefix} Sleep for {time_sleep} seconds before getting all opened PRs") await asyncio.sleep(time_sleep) - pulls = await asyncio.to_thread(lambda: list(self.repository.get_pulls(state="open"))) + pulls = await github_api_call( + lambda: list(self.repository.get_pulls(state="open")), logger=self.logger, log_prefix=self.log_prefix + ) for pull_request in pulls: self.logger.info(f"{self.log_prefix} check label pull request after merge") await self.label_pull_request_by_merge_state(pull_request=pull_request, add_only=True) @@ -868,10 +884,12 @@ async def _delete_ghcr_tag_via_github_api( for scope in ("orgs", "users"): candidate_base = f"/{scope}/{owner_name}/packages/container/{package_name}" try: - _, versions = await asyncio.to_thread( + _, versions = await github_api_call( self.github_webhook.github_api.requester.requestJsonAndCheck, "GET", f"{candidate_base}/versions", + logger=self.logger, + log_prefix=self.log_prefix, ) package_api_base = candidate_base break @@ -905,8 +923,12 @@ async def _delete_ghcr_tag_via_github_api( # DELETE /{scope}/{owner}/packages/{package_type}/{package_name}/versions/{package_version_id} delete_url = f"{package_api_base}/versions/{version_to_delete_id}" try: - await asyncio.to_thread( - self.github_webhook.github_api.requester.requestJsonAndCheck, "DELETE", delete_url + await github_api_call( + self.github_webhook.github_api.requester.requestJsonAndCheck, + "DELETE", + delete_url, + logger=self.logger, + log_prefix=self.log_prefix, ) except GithubException as ex: if ex.status == 404: @@ -918,9 +940,17 @@ async def _delete_ghcr_tag_via_github_api( else: raise - await asyncio.to_thread( - pull_request.create_issue_comment, f"Successfully removed PR tag: {repository_full_tag}." - ) + try: + await github_api_call( + pull_request.create_issue_comment, + f"Successfully removed PR tag: {repository_full_tag}.", + logger=self.logger, + log_prefix=self.log_prefix, + ) + except Exception: + self.logger.exception( + f"{self.log_prefix} Tag cleanup succeeded, but PR notification failed: {repository_full_tag}" + ) except GithubException: self.logger.exception(f"{self.log_prefix} Failed to delete GHCR tag: {repository_full_tag}") @@ -963,9 +993,18 @@ async def _delete_registry_tag_via_regctl( redact_secrets=redact_values, ) if rc: - await asyncio.to_thread( - pull_request.create_issue_comment, f"Successfully removed PR tag: {repository_full_tag}." - ) + try: + await github_api_call( + pull_request.create_issue_comment, + f"Successfully removed PR tag: {repository_full_tag}.", + logger=self.logger, + log_prefix=self.log_prefix, + ) + except Exception: + self.logger.exception( + f"{self.log_prefix} Tag cleanup succeeded," + f" but PR notification failed: {repository_full_tag}" + ) else: self.logger.error( f"{self.log_prefix} Failed to delete tag: {repository_full_tag}. " @@ -981,10 +1020,17 @@ async def _delete_registry_tag_via_regctl( await self.runner_handler.run_podman_command(command="regctl registry logout") else: - await asyncio.to_thread( - pull_request.create_issue_comment, - f"Failed to delete tag: {repository_full_tag}. Please delete it manually.", - ) + try: + await github_api_call( + pull_request.create_issue_comment, + f"Failed to delete tag: {repository_full_tag}. Please delete it manually.", + logger=self.logger, + log_prefix=self.log_prefix, + ) + except Exception: + self.logger.exception( + f"{self.log_prefix} Tag cleanup failed, and PR notification also failed: {repository_full_tag}" + ) self.logger.error(f"{self.log_prefix} Failed to delete tag: {repository_full_tag}. OUT:{out}. ERR:{err}") async def close_issue_for_merged_or_closed_pr(self, pull_request: PullRequest, hook_action: str) -> None: @@ -996,19 +1042,23 @@ def _find_matching_issue() -> Any | None: return existing_issue return None - matching_issue = await asyncio.to_thread(_find_matching_issue) + matching_issue = await github_api_call(_find_matching_issue, logger=self.logger, log_prefix=self.log_prefix) if not matching_issue: return - pr_title = await asyncio.to_thread(lambda: pull_request.title) - issue_title = await asyncio.to_thread(lambda: matching_issue.title) + pr_title = await github_api_call(lambda: pull_request.title, logger=self.logger, log_prefix=self.log_prefix) + issue_title = await github_api_call( + lambda: matching_issue.title, logger=self.logger, log_prefix=self.log_prefix + ) self.logger.info(f"{self.log_prefix} Closing issue {issue_title} for PR: {pr_title}") - await asyncio.to_thread( + await github_api_call( matching_issue.create_comment, f"{self.log_prefix} Closing issue for PR: {pr_title}.\nPR was {hook_action}.", + logger=self.logger, + log_prefix=self.log_prefix, ) - await asyncio.to_thread(matching_issue.edit, state="closed") + await github_api_call(matching_issue.edit, state="closed", logger=self.logger, log_prefix=self.log_prefix) async def process_opened_or_synchronize_pull_request( self, pull_request: PullRequest, is_clean_rebase: bool = False, label_names: list[str] | None = None @@ -1118,11 +1168,16 @@ async def create_issue_for_new_pull_request(self, pull_request: PullRequest) -> return self.logger.info(f"{self.log_prefix} Creating issue for new PR: {pull_request.title}") - await asyncio.to_thread( + assignee_login = await github_api_call( + lambda: pull_request.user.login, logger=self.logger, log_prefix=self.log_prefix + ) + await github_api_call( self.repository.create_issue, title=self._generate_issue_title(pull_request=pull_request), body=self._generate_issue_body(pull_request=pull_request), - assignee=pull_request.user.login, + assignee=assignee_login, + logger=self.logger, + log_prefix=self.log_prefix, ) def _generate_issue_title(self, pull_request: PullRequest) -> str: @@ -1148,14 +1203,18 @@ async def set_pull_request_automerge(self, pull_request: PullRequest) -> None: if auto_merge: # AI-resolved cherry-picks should NEVER be auto-merged - labels = await asyncio.to_thread(lambda: list(pull_request.labels)) + labels = await github_api_call( + lambda: list(pull_request.labels), logger=self.logger, log_prefix=self.log_prefix + ) if any(label.name == AI_RESOLVED_CONFLICTS_LABEL for label in labels): if pull_request.raw_data.get("auto_merge"): try: self.logger.info( f"{self.log_prefix} AI-resolved cherry-pick has auto-merge enabled, disabling it" ) - await asyncio.to_thread(pull_request.disable_automerge) + await github_api_call( + pull_request.disable_automerge, logger=self.logger, log_prefix=self.log_prefix + ) except Exception: self.logger.exception( f"{self.log_prefix} Failed to disable auto-merge for AI-resolved cherry-pick" @@ -1172,7 +1231,12 @@ async def set_pull_request_automerge(self, pull_request: PullRequest) -> None: f"is part of auto merge enabled rules" ) - await asyncio.to_thread(pull_request.enable_automerge, merge_method="SQUASH") + await github_api_call( + pull_request.enable_automerge, + merge_method="SQUASH", + logger=self.logger, + log_prefix=self.log_prefix, + ) else: self.logger.debug(f"{self.log_prefix} is already set to auto merge") @@ -1181,7 +1245,9 @@ async def set_pull_request_automerge(self, pull_request: PullRequest) -> None: async def remove_labels_when_pull_request_sync(self, pull_request: PullRequest) -> None: tasks: list[Coroutine[Any, Any, Any]] = [] - labels = await asyncio.to_thread(lambda: list(pull_request.labels)) + labels = await github_api_call( + lambda: list(pull_request.labels), logger=self.logger, log_prefix=self.log_prefix + ) for _label in labels: _label_name = _label.name if ( @@ -1225,10 +1291,12 @@ async def _compare_branches(self, base_ref: str, head_ref_full: str) -> dict[str NOTE: This API does NOT return conflict information (mergeable/mergeable_state). """ try: - _, data = await asyncio.to_thread( + _, data = await github_api_call( self.repository._requester.requestJsonAndCheck, "GET", f"{self.repository.url}/compare/{base_ref}...{head_ref_full}", + logger=self.logger, + log_prefix=self.log_prefix, ) return data except GithubException: @@ -1271,7 +1339,9 @@ async def label_pull_request_by_merge_state(self, pull_request: PullRequest, add # Step 1: Check for conflicts first # GitHub may return mergeable=None while computing - poll until definitive - mergeable = await asyncio.to_thread(lambda: pull_request.mergeable) + mergeable = await github_api_call( + lambda: pull_request.mergeable, logger=self.logger, log_prefix=self.log_prefix + ) if mergeable is None: self.logger.debug( @@ -1291,7 +1361,7 @@ def _poll_mergeable() -> bool | None: return None # pragma: no cover try: - mergeable = await asyncio.to_thread(_poll_mergeable) + mergeable = await github_api_call(_poll_mergeable, logger=self.logger, log_prefix=self.log_prefix) except asyncio.CancelledError: raise except TimeoutExpiredError: @@ -1325,9 +1395,9 @@ def _poll_mergeable() -> bool | None: # Step 3: Check if needs rebase via Compare API base_ref, head_user_login, head_ref = await asyncio.gather( - asyncio.to_thread(lambda: pull_request.base.ref), - asyncio.to_thread(lambda: pull_request.head.user.login), - asyncio.to_thread(lambda: pull_request.head.ref), + github_api_call(lambda: pull_request.base.ref, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.head.user.login, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.head.ref, logger=self.logger, log_prefix=self.log_prefix), ) head_ref_full = f"{head_user_login}:{head_ref}" @@ -1373,7 +1443,9 @@ async def _process_verified_for_update_or_new_pull_request(self, pull_request: P return # Check if this is a cherry-picked PR - labels = await asyncio.to_thread(lambda: list(pull_request.labels)) + labels = await github_api_call( + lambda: list(pull_request.labels), logger=self.logger, log_prefix=self.log_prefix + ) # AI-resolved cherry-picks are NEVER auto-verified (takes precedence over auto-verify-cherry-picked-prs) is_ai_resolved = any(label.name == AI_RESOLVED_CONFLICTS_LABEL for label in labels) @@ -1427,13 +1499,23 @@ async def _sync_verified_check_for_clean_rebase(self, _pull_request: PullRequest async def add_pull_request_owner_as_assingee(self, pull_request: PullRequest) -> None: try: self.logger.info(f"{self.log_prefix} Adding PR owner as assignee") - await asyncio.to_thread(pull_request.add_to_assignees, pull_request.user.login) + assignee_login = await github_api_call( + lambda: pull_request.user.login, logger=self.logger, log_prefix=self.log_prefix + ) + await github_api_call( + pull_request.add_to_assignees, assignee_login, logger=self.logger, log_prefix=self.log_prefix + ) except Exception as exp: self.logger.debug(f"{self.log_prefix} Exception while adding PR owner as assignee: {exp}") if self.owners_file_handler.root_approvers: self.logger.debug(f"{self.log_prefix} Falling back to first approver as assignee") - await asyncio.to_thread(pull_request.add_to_assignees, self.owners_file_handler.root_approvers[0]) + await github_api_call( + pull_request.add_to_assignees, + self.owners_file_handler.root_approvers[0], + logger=self.logger, + log_prefix=self.log_prefix, + ) async def check_if_can_be_merged(self, pull_request: PullRequest) -> None: """ @@ -1467,8 +1549,16 @@ async def check_if_can_be_merged(self, pull_request: PullRequest) -> None: self.logger.info(f"{self.log_prefix} Check if {CAN_BE_MERGED_STR}.") await self.check_run_handler.set_check_in_progress(name=CAN_BE_MERGED_STR) # Fetch check runs, statuses, and optionally unresolved threads in parallel - _check_runs_task = asyncio.to_thread(lambda: list(self.github_webhook.last_commit.get_check_runs())) - _statuses_task = asyncio.to_thread(lambda: list(self.github_webhook.last_commit.get_statuses())) + _check_runs_task = github_api_call( + lambda: list(self.github_webhook.last_commit.get_check_runs()), + logger=self.logger, + log_prefix=self.log_prefix, + ) + _statuses_task = github_api_call( + lambda: list(self.github_webhook.last_commit.get_statuses()), + logger=self.logger, + log_prefix=self.log_prefix, + ) _unresolved_threads: list[dict[str, Any]] = [] if self.github_webhook.required_conversation_resolution: @@ -1488,7 +1578,9 @@ async def check_if_can_be_merged(self, pull_request: PullRequest) -> None: _labels = await self.labels_handler.pull_request_labels_names(pull_request=pull_request) self.logger.debug(f"{self.log_prefix} check if can be merged. PR labels are: {_labels}") - is_pr_mergable = await asyncio.to_thread(lambda: pull_request.mergeable) + is_pr_mergable = await github_api_call( + lambda: pull_request.mergeable, logger=self.logger, log_prefix=self.log_prefix + ) self.logger.debug(f"{self.log_prefix} PR mergeable is {is_pr_mergable}") if not is_pr_mergable: failure_output += f"PR is not mergeable: {is_pr_mergable}\n" @@ -1677,7 +1769,7 @@ def _check_labels_for_can_be_merged(self, labels: list[str]) -> str: return failure_output async def skip_if_pull_request_already_merged(self, pull_request: PullRequest) -> bool: - if pull_request and await asyncio.to_thread(lambda: pull_request.is_merged()): + if await github_api_call(lambda: pull_request.is_merged(), logger=self.logger, log_prefix=self.log_prefix): self.logger.info(f"{self.log_prefix}: PR is merged, not processing") return True @@ -1692,7 +1784,7 @@ def check_comments() -> bool: for comment in pull_request.get_issue_comments() ) - return await asyncio.to_thread(check_comments) + return await github_api_call(check_comments, logger=self.logger, log_prefix=self.log_prefix) async def regenerate_welcome_message(self, pull_request: PullRequest) -> None: """Regenerate and update the welcome message for this PR. @@ -1710,13 +1802,15 @@ def find_and_update_welcome_comment() -> bool: return True return False - updated = await asyncio.to_thread(find_and_update_welcome_comment) + updated = await github_api_call(find_and_update_welcome_comment, logger=self.logger, log_prefix=self.log_prefix) if updated: self.logger.info(f"{self.log_prefix} Updated existing welcome message") else: self.logger.info(f"{self.log_prefix} Creating new welcome message") - await asyncio.to_thread(pull_request.create_issue_comment, body=welcome_msg) + await github_api_call( + pull_request.create_issue_comment, body=welcome_msg, logger=self.logger, log_prefix=self.log_prefix + ) async def _tracking_issue_exists(self, pull_request: PullRequest) -> bool: """Check if tracking issue already exists for this PR.""" @@ -1725,7 +1819,7 @@ async def _tracking_issue_exists(self, pull_request: PullRequest) -> bool: def check_issues() -> bool: return any(issue.body == expected_body for issue in self.repository.get_issues()) - return await asyncio.to_thread(check_issues) + return await github_api_call(check_issues, logger=self.logger, log_prefix=self.log_prefix) async def process_new_or_reprocess_pull_request(self, pull_request: PullRequest) -> None: """Process a new or reprocessed PR - handles welcome message, tracking issue, and full workflow. @@ -1739,7 +1833,11 @@ async def process_new_or_reprocess_pull_request(self, pull_request: PullRequest) if not await self._welcome_comment_exists(pull_request=pull_request): self.logger.info(f"{self.log_prefix} Adding welcome message to PR") welcome_msg = self._prepare_welcome_comment() - tasks.append(asyncio.to_thread(pull_request.create_issue_comment, body=welcome_msg)) + tasks.append( + github_api_call( + pull_request.create_issue_comment, body=welcome_msg, logger=self.logger, log_prefix=self.log_prefix + ) + ) else: self.logger.info(f"{self.log_prefix} Welcome message already exists, skipping") @@ -1765,7 +1863,7 @@ async def process_new_or_reprocess_pull_request(self, pull_request: PullRequest) async def process_command_reprocess(self, pull_request: PullRequest) -> None: """Handle /reprocess command - triggers full PR workflow from scratch.""" # Check if PR is already merged - skip if merged - if await asyncio.to_thread(lambda: pull_request.is_merged()): + if await github_api_call(lambda: pull_request.is_merged(), logger=self.logger, log_prefix=self.log_prefix): self.logger.info(f"{self.log_prefix} PR is already merged, skipping reprocess") return diff --git a/webhook_server/libs/handlers/push_handler.py b/webhook_server/libs/handlers/push_handler.py index 16d56597..8a62cf50 100644 --- a/webhook_server/libs/handlers/push_handler.py +++ b/webhook_server/libs/handlers/push_handler.py @@ -1,4 +1,3 @@ -import asyncio import re import traceback from typing import TYPE_CHECKING @@ -7,6 +6,7 @@ from webhook_server.libs.handlers.check_run_handler import CheckRunHandler from webhook_server.libs.handlers.runner_handler import RunnerHandler +from webhook_server.utils.github_retry import github_api_call from webhook_server.utils.helpers import run_command from webhook_server.utils.notification_utils import send_slack_message @@ -67,12 +67,14 @@ async def _issue_on_error(_error: str) -> None: # Truncate to safe length (GitHub issue title limit is ~256 chars, use 250 for safety) if len(sanitized_title) > 250: sanitized_title = sanitized_title[:247] + "..." - await asyncio.to_thread( + await github_api_call( self.repository.create_issue, title=sanitized_title, body=f""" Publish to PYPI failed: `{_error}` """, + logger=self.logger, + log_prefix=self.log_prefix, ) self.logger.info(f"{self.log_prefix} Start uploading to pypi") diff --git a/webhook_server/libs/handlers/runner_handler.py b/webhook_server/libs/handlers/runner_handler.py index 9e1f9112..721d1d42 100644 --- a/webhook_server/libs/handlers/runner_handler.py +++ b/webhook_server/libs/handlers/runner_handler.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any import shortuuid -from github.Branch import Branch +from github import GithubException from github.PullRequest import PullRequest from github.Repository import Repository @@ -30,6 +30,7 @@ TOX_STR, ) from webhook_server.utils.github_repository_settings import get_repository_github_app_token +from webhook_server.utils.github_retry import github_api_call from webhook_server.utils.helpers import _redact_secrets, run_command from webhook_server.utils.notification_utils import send_slack_message @@ -96,8 +97,12 @@ async def _checkout_worktree( pr_number: int | None = None base_ref: str | None = None if pull_request: - pr_number = await asyncio.to_thread(lambda: pull_request.number) - base_ref = await asyncio.to_thread(lambda: pull_request.base.ref) + pr_number = await github_api_call( + lambda: pull_request.number, logger=self.logger, log_prefix=self.log_prefix + ) + base_ref = await github_api_call( + lambda: pull_request.base.ref, logger=self.logger, log_prefix=self.log_prefix + ) # Determine what to checkout checkout_target = "" @@ -154,7 +159,9 @@ async def _checkout_worktree( if success and pull_request and not is_merged and not tag_name and not skip_merge: merge_ref = base_ref if merge_ref is None: - merge_ref = await asyncio.to_thread(lambda: pull_request.base.ref) + merge_ref = await github_api_call( + lambda: pull_request.base.ref, logger=self.logger, log_prefix=self.log_prefix + ) git_cmd = f"git -C {worktree_path}" rc, out, err = await run_command( command=f"{git_cmd} merge origin/{merge_ref} -m 'Merge {merge_ref}'", @@ -278,8 +285,8 @@ async def run_tox(self, pull_request: PullRequest) -> None: python_ver = ( f"--python={self.github_webhook.tox_python_version}" if self.github_webhook.tox_python_version else "" ) - # Wrap PyGithub property access in asyncio.to_thread to avoid blocking - base_ref = await asyncio.to_thread(lambda: pull_request.base.ref) + # Wrap PyGithub property access to avoid blocking + base_ref = await github_api_call(lambda: pull_request.base.ref, logger=self.logger, log_prefix=self.log_prefix) _tox_tests = self.github_webhook.tox.get(base_ref, "") # Build tox command with {worktree_path} placeholder @@ -402,7 +409,9 @@ async def run_build_container( if push_rc: push_msg: str = f"New container for {_container_repository_and_tag} published" if pull_request: - await asyncio.to_thread(pull_request.create_issue_comment, push_msg) + await github_api_call( + pull_request.create_issue_comment, push_msg, logger=self.logger, log_prefix=self.log_prefix + ) if self.github_webhook.slack_webhook_url: message = f""" @@ -421,7 +430,9 @@ async def run_build_container( else: err_msg: str = f"Failed to build and push {_container_repository_and_tag}" if pull_request: - await asyncio.to_thread(pull_request.create_issue_comment, err_msg) + await github_api_call( + pull_request.create_issue_comment, err_msg, logger=self.logger, log_prefix=self.log_prefix + ) if self.github_webhook.slack_webhook_url: message = f""" @@ -547,7 +558,9 @@ async def run_conventional_title_check(self, pull_request: PullRequest) -> None: if suggestion_valid and ai_suggestion != title: self.logger.info(f"{self.log_prefix} AI fixing PR title from '{title}' to '{ai_suggestion}'") try: - await asyncio.to_thread(pull_request.edit, title=ai_suggestion) + await github_api_call( + pull_request.edit, title=ai_suggestion, logger=self.logger, log_prefix=self.log_prefix + ) output["title"] = "Conventional Title" output["summary"] = "PR title auto-fixed by AI" output["text"] = ( @@ -708,8 +721,19 @@ async def run_custom_check( ) await self.run_check(pull_request=pull_request, check_config=unified_config) - async def is_branch_exists(self, branch: str) -> Branch: - return await asyncio.to_thread(self.repository.get_branch, branch) + async def is_branch_exists(self, branch: str) -> bool: + try: + await github_api_call( + self.repository.get_branch, + branch, + logger=self.logger, + log_prefix=self.log_prefix, + ) + return True + except GithubException as ex: + if ex.status == 404: + return False + raise async def _resolve_cherry_pick_with_ai( self, @@ -833,8 +857,12 @@ async def cherry_pick( target_branch: str, assign_to_pr_owner: bool = True, ) -> None: - pr_author = await asyncio.to_thread(lambda: pull_request.user.login) - source_branch = await asyncio.to_thread(lambda: pull_request.base.ref) + pr_author = await github_api_call( + lambda: pull_request.user.login, logger=self.logger, log_prefix=self.log_prefix + ) + source_branch = await github_api_call( + lambda: pull_request.base.ref, logger=self.logger, log_prefix=self.log_prefix + ) self.logger.info( f"{self.log_prefix} Cherry-pick from {source_branch} to {target_branch}, PR owner: {pr_author}" @@ -844,7 +872,9 @@ async def cherry_pick( if not await self.is_branch_exists(branch=target_branch): err_msg = f"cherry-pick failed: {target_branch} does not exists" self.logger.error(err_msg) - await asyncio.to_thread(pull_request.create_issue_comment, err_msg) + await github_api_call( + pull_request.create_issue_comment, err_msg, logger=self.logger, log_prefix=self.log_prefix + ) else: await self.check_run_handler.set_check_in_progress(name=CHERRY_PICKED_LABEL) @@ -907,7 +937,7 @@ async def cherry_pick( ) self.logger.error(f"{self.log_prefix} Cherry pick failed: {redacted_out} --- {redacted_err}") local_branch_name = f"{pull_request.head.ref}-{target_branch}" - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"**Manual cherry-pick is needed**\nCherry pick failed for " f"{commit_hash} to {target_branch}:\n" @@ -922,6 +952,8 @@ async def cherry_pick( f"# git cherry-pick -m 1 {commit_hash}\n" f"git push origin {local_branch_name}\n" "```", + logger=self.logger, + log_prefix=self.log_prefix, ) return @@ -972,7 +1004,7 @@ async def cherry_pick( ) self.logger.error(f"{self.log_prefix} Cherry pick failed: {redacted_out} --- {redacted_err}") local_branch_name = f"{pull_request.head.ref}-{target_branch}" - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"**Manual cherry-pick is needed**\nCherry pick failed for " f"{commit_hash} to {target_branch}:\n" @@ -987,6 +1019,8 @@ async def cherry_pick( f"# git cherry-pick -m 1 {commit_hash}\n" f"git push origin {local_branch_name}\n" "```", + logger=self.logger, + log_prefix=self.log_prefix, ) return cherry_pick_had_conflicts = True @@ -1022,10 +1056,12 @@ async def cherry_pick( # Use GitHub App installation token for PR creation # so the PR is owned by the app bot, allowing repo collaborators to push try: - app_token = await asyncio.to_thread( + app_token = await github_api_call( get_repository_github_app_token, config_=self.github_webhook.config, repository_name=self.github_webhook.repository_full_name, + logger=self.logger, + log_prefix=self.log_prefix, ) except Exception: self.logger.exception( @@ -1045,9 +1081,8 @@ async def cherry_pick( if not rc: output["text"] = self.check_run_handler.get_check_run_text(err=err, out=out) await self.check_run_handler.set_check_failure(name=CHERRY_PICKED_LABEL, output=output) - await asyncio.to_thread( - pull_request.create_issue_comment, - f"**Cherry-pick branch created, but PR creation failed**\n" + body = ( + "**Cherry-pick branch created, but PR creation failed**\n" f"Branch `{new_branch_name}` was pushed to the repository.\n" f"Create the PR manually:\n" "```\n" @@ -1058,7 +1093,13 @@ async def cherry_pick( + (f" --label {AI_RESOLVED_CONFLICTS_LABEL}" if cherry_pick_had_conflicts else "") + f" --title '{pr_title}'" f" --body '{pr_body}'\n" - "```", + "```" + ) + await github_api_call( + pull_request.create_issue_comment, + body, + logger=self.logger, + log_prefix=self.log_prefix, ) redacted_out = _redact_secrets( out, @@ -1081,7 +1122,9 @@ async def cherry_pick( # Get the cherry-pick PR object try: pr_number = int(cherry_pick_pr_url.rstrip("/").split("/")[-1]) - cherry_pick_pr = await asyncio.to_thread(self.repository.get_pull, pr_number) + cherry_pick_pr = await github_api_call( + self.repository.get_pull, pr_number, logger=self.logger, log_prefix=self.log_prefix + ) except Exception: self.logger.exception( f"{self.log_prefix} Failed to get cherry-pick PR from URL: {cherry_pick_pr_url}" @@ -1092,7 +1135,12 @@ async def cherry_pick( # Assign the PR to the original author (or fallback approver) if assign_to_pr_owner: try: - await asyncio.to_thread(cherry_pick_pr.add_to_assignees, pr_author) + await github_api_call( + cherry_pick_pr.add_to_assignees, + pr_author, + logger=self.logger, + log_prefix=self.log_prefix, + ) self.logger.info( f"{self.log_prefix} Assigned {pr_author} to cherry-pick PR #{cherry_pick_pr.number}" ) @@ -1104,7 +1152,12 @@ async def cherry_pick( try: fallback_approvers = self.owners_file_handler.root_approvers if fallback_approvers: - await asyncio.to_thread(cherry_pick_pr.add_to_assignees, fallback_approvers[0]) + await github_api_call( + cherry_pick_pr.add_to_assignees, + fallback_approvers[0], + logger=self.logger, + log_prefix=self.log_prefix, + ) self.logger.info( f"{self.log_prefix} Assigned fallback approver" f" {fallback_approvers[0]} to cherry-pick PR #{cherry_pick_pr.number}" @@ -1125,7 +1178,9 @@ async def cherry_pick( labels_to_add = [cherry_picked_label] if cherry_pick_had_conflicts: labels_to_add.append(AI_RESOLVED_CONFLICTS_LABEL) - await asyncio.to_thread(cherry_pick_pr.add_to_labels, *labels_to_add) + await github_api_call( + cherry_pick_pr.add_to_labels, *labels_to_add, logger=self.logger, log_prefix=self.log_prefix + ) self.logger.info( f"{self.log_prefix} Added labels {labels_to_add} to cherry-pick PR #{cherry_pick_pr.number}" ) @@ -1133,31 +1188,40 @@ async def cherry_pick( self.logger.exception(f"{self.log_prefix} Failed to add labels to cherry-pick PR") # Labels are critical for auto-verify skip — warn if they couldn't be added try: - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"**Warning:** Failed to add labels to cherry-pick PR {cherry_pick_pr_url}. " f"Please manually add the `{cherry_picked_label}` label" + (f" and `{AI_RESOLVED_CONFLICTS_LABEL}` label" if cherry_pick_had_conflicts else "") + " to ensure correct auto-verify behavior.", + logger=self.logger, + log_prefix=self.log_prefix, ) except Exception: self.logger.exception(f"{self.log_prefix} Failed to post label warning comment") # Request review from original PR author (independent of label success) try: - await asyncio.to_thread(cherry_pick_pr.create_review_request, reviewers=[pr_author]) + await github_api_call( + cherry_pick_pr.create_review_request, + reviewers=[pr_author], + logger=self.logger, + log_prefix=self.log_prefix, + ) except Exception: self.logger.debug( f"{self.log_prefix} Could not request review from {pr_author} (may not be a collaborator)" ) else: # PR was created but we couldn't fetch it — labels/reviewer not added - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"**Warning:** Cherry-pick PR was created ({cherry_pick_pr_url}) but failed to add labels. " f"Please manually add the `{cherry_picked_label}` label" + (f" and `{AI_RESOLVED_CONFLICTS_LABEL}` label" if cherry_pick_had_conflicts else "") + " to ensure correct auto-verify behavior.", + logger=self.logger, + log_prefix=self.log_prefix, ) output["text"] = self.check_run_handler.get_check_run_text(err=err, out=out) @@ -1167,17 +1231,21 @@ async def cherry_pick( ai_config = self.github_webhook.ai_features ai_result = get_ai_config(ai_config) ai_provider, ai_model = ai_result if ai_result else ("unknown", "unknown") - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"**Cherry-pick conflicts were resolved by AI**\n\n" f"Cherry-picked PR {pull_request.title} into {target_branch}: {cherry_pick_pr_url}\n" f"Conflicts were automatically resolved by AI ({ai_provider}/{ai_model}).\n\n" f"**Manual verification is required** — please review the changes and test before merging.", + logger=self.logger, + log_prefix=self.log_prefix, ) else: - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"Cherry-picked PR {pull_request.title} into {target_branch}: {cherry_pick_pr_url}", + logger=self.logger, + log_prefix=self.log_prefix, ) async def run_retests(self, supported_retests: list[str], pull_request: PullRequest) -> None: diff --git a/webhook_server/libs/test_oracle.py b/webhook_server/libs/test_oracle.py index 7dc929a0..d365d6ce 100644 --- a/webhook_server/libs/test_oracle.py +++ b/webhook_server/libs/test_oracle.py @@ -5,6 +5,8 @@ import httpx +from webhook_server.utils.github_retry import github_api_call + if TYPE_CHECKING: from github.PullRequest import PullRequest @@ -57,16 +59,20 @@ async def call_test_oracle( msg = f"Test Oracle server at {server_url} is not responding{status_info}, skipping test analysis" github_webhook.logger.warning(f"{log_prefix} {msg}") try: - await asyncio.to_thread( + await github_api_call( pull_request.create_issue_comment, f"Test Oracle server is not responding{status_info}, skipping test analysis", + logger=github_webhook.logger, + log_prefix=log_prefix, ) except Exception: github_webhook.logger.exception(f"{log_prefix} Failed to post health check comment") return # Build analyze payload - pr_url: str = await asyncio.to_thread(lambda: pull_request.html_url) + pr_url: str = await github_api_call( + lambda: pull_request.html_url, logger=github_webhook.logger, log_prefix=log_prefix + ) payload: dict[str, Any] = { "pr_url": pr_url, "ai_provider": config["ai-provider"], diff --git a/webhook_server/tests/test_clean_rebase_detection.py b/webhook_server/tests/test_clean_rebase_detection.py index 51cb9d03..538411fc 100644 --- a/webhook_server/tests/test_clean_rebase_detection.py +++ b/webhook_server/tests/test_clean_rebase_detection.py @@ -659,7 +659,7 @@ async def test_synchronize_clean_rebase_posts_comment_with_preserved_labels( # Labels should NOT be fetched via API call - they come from webhook payload handler.labels_handler.pull_request_labels_names.assert_not_called() - # create_issue_comment is called via asyncio.to_thread which executes it + # create_issue_comment is called via github_api_call which executes it mock_pull_request.create_issue_comment.assert_called_once() comment_body = mock_pull_request.create_issue_comment.call_args.kwargs["body"] assert "Clean rebase detected" in comment_body diff --git a/webhook_server/tests/test_github_retry.py b/webhook_server/tests/test_github_retry.py new file mode 100644 index 00000000..576aac3a --- /dev/null +++ b/webhook_server/tests/test_github_retry.py @@ -0,0 +1,300 @@ +"""Tests for webhook_server.utils.github_retry module.""" + +import asyncio +import logging +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from github.GithubException import BadCredentialsException, GithubException, UnknownObjectException +from requests.exceptions import ConnectionError as RequestsConnectionError +from urllib3.exceptions import MaxRetryError, ResponseError + +from webhook_server.utils.github_retry import _is_retryable, github_api_call + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +LOG_PREFIX = "[test/repo #1]" + + +@pytest.fixture() +def mock_logger() -> logging.Logger: + """Provide a mock logger for github_api_call tests.""" + return Mock(spec=logging.Logger) + + +# --------------------------------------------------------------------------- +# _is_retryable unit tests +# --------------------------------------------------------------------------- + + +class TestIsRetryable: + def test_github_500_is_retryable(self): + ex = GithubException(status=500, data={"message": "Internal Server Error"}) + assert _is_retryable(ex) is True + + def test_github_502_is_retryable(self): + ex = GithubException(status=502, data={"message": "Bad Gateway"}) + assert _is_retryable(ex) is True + + def test_github_503_is_retryable(self): + ex = GithubException(status=503, data={"message": "Service Unavailable"}) + assert _is_retryable(ex) is True + + def test_github_exception_504_retryable(self): + ex = GithubException(status=504, data={"message": "Gateway Timeout"}) + assert _is_retryable(ex) is True + + def test_github_404_not_retryable(self): + ex = UnknownObjectException(status=404, data={"message": "Not Found"}) + assert _is_retryable(ex) is False + + def test_github_401_not_retryable(self): + ex = BadCredentialsException(status=401, data={"message": "Bad credentials"}) + assert _is_retryable(ex) is False + + def test_github_403_not_retryable(self): + ex = GithubException(status=403, data={"message": "Forbidden"}) + assert _is_retryable(ex) is False + + def test_github_422_not_retryable(self): + ex = GithubException(status=422, data={"message": "Unprocessable Entity"}) + assert _is_retryable(ex) is False + + def test_github_unknown_status_not_retryable(self): + ex = GithubException(status=418, data={"message": "I'm a teapot"}) + assert _is_retryable(ex) is False + + def test_requests_connection_error_is_retryable(self): + ex = RequestsConnectionError("Connection refused") + assert _is_retryable(ex) is True + + def test_urllib3_max_retry_error_is_retryable(self): + ex = MaxRetryError(pool=Mock(), url="https://api.github.com") + assert _is_retryable(ex) is True + + def test_exception_with_500_error_responses_substring(self): + ex = Exception("Got 500 error responses from server") + assert _is_retryable(ex) is True + + def test_exception_with_max_retries_exceeded_substring(self): + ex = Exception("Max retries exceeded with url: /repos/org/repo") + assert _is_retryable(ex) is True + + def test_generic_value_error_not_retryable(self): + ex = ValueError("something went wrong") + assert _is_retryable(ex) is False + + def test_generic_runtime_error_not_retryable(self): + ex = RuntimeError("unexpected failure") + assert _is_retryable(ex) is False + + def test_urllib3_response_error_is_retryable(self): + ex = ResponseError("connection reset by peer") + assert _is_retryable(ex) is True + + +# --------------------------------------------------------------------------- +# github_api_call tests +# --------------------------------------------------------------------------- + + +class TestGithubApiCall: + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_successful_call_first_attempt(self, mock_sleep, mock_logger): + func = Mock(return_value=42) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == 42 + func.assert_called_once() + mock_sleep.assert_not_called() + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_successful_call_after_transient_failure(self, mock_sleep, mock_logger): + func = Mock( + side_effect=[ + GithubException(status=500, data={"message": "Internal Server Error"}), + "success", + ] + ) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "success" + assert func.call_count == 2 + mock_sleep.assert_called_once_with(2) + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_github_500(self, mock_sleep, mock_logger): + ex = GithubException(status=500, data={"message": "Internal Server Error"}) + func = Mock(side_effect=[ex, ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 3 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_github_502(self, mock_sleep, mock_logger): + ex = GithubException(status=502, data={"message": "Bad Gateway"}) + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_github_503(self, mock_sleep, mock_logger): + ex = GithubException(status=503, data={"message": "Service Unavailable"}) + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_github_exception_504(self, mock_sleep, mock_logger): + ex = GithubException(status=504, data={"message": "Gateway Timeout"}) + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_no_retry_on_404(self, mock_sleep, mock_logger): + ex = UnknownObjectException(status=404, data={"message": "Not Found"}) + func = Mock(side_effect=ex) + with pytest.raises(UnknownObjectException): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + func.assert_called_once() + mock_sleep.assert_not_called() + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_no_retry_on_401(self, mock_sleep, mock_logger): + ex = BadCredentialsException(status=401, data={"message": "Bad credentials"}) + func = Mock(side_effect=ex) + with pytest.raises(BadCredentialsException): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + func.assert_called_once() + mock_sleep.assert_not_called() + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_no_retry_on_403(self, mock_sleep, mock_logger): + ex = GithubException(status=403, data={"message": "Forbidden"}) + func = Mock(side_effect=ex) + with pytest.raises(GithubException): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + func.assert_called_once() + mock_sleep.assert_not_called() + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_no_retry_on_422(self, mock_sleep, mock_logger): + ex = GithubException(status=422, data={"message": "Unprocessable Entity"}) + func = Mock(side_effect=ex) + with pytest.raises(GithubException): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + func.assert_called_once() + mock_sleep.assert_not_called() + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_requests_connection_error(self, mock_sleep, mock_logger): + ex = RequestsConnectionError("Connection refused") + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_max_retry_error(self, mock_sleep, mock_logger): + ex = MaxRetryError(pool=Mock(), url="https://api.github.com") + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_500_error_responses_message(self, mock_sleep, mock_logger): + ex = Exception("Got 500 error responses from server") + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_max_retries_exceeded_message(self, mock_sleep, mock_logger): + ex = Exception("Max retries exceeded with url: /repos/org/repo") + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_max_retries_exhausted(self, mock_sleep, mock_logger): + ex = GithubException(status=500, data={"message": "Internal Server Error"}) + func = Mock(side_effect=ex) + with pytest.raises(GithubException, match="Internal Server Error"): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + # _MAX_RETRIES + 1 total attempts = 5 + assert func.call_count == 5 + assert mock_sleep.call_count == 4 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_cancelled_error_always_reraised(self, mock_sleep, mock_logger): + func = Mock(side_effect=asyncio.CancelledError) + with pytest.raises(asyncio.CancelledError): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + func.assert_called_once() + mock_sleep.assert_not_called() + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_lambda_calls_work(self, mock_sleep, mock_logger): + result = await github_api_call(lambda: "lambda_value", logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "lambda_value" + mock_sleep.assert_not_called() + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_kwargs_are_forwarded(self, mock_sleep, mock_logger): + func = Mock(return_value="done") + result = await github_api_call( + func, "pos_arg", logger=mock_logger, log_prefix=LOG_PREFIX, key1="val1", key2="val2" + ) + assert result == "done" + func.assert_called_once_with("pos_arg", key1="val1", key2="val2") + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_exponential_backoff_timing(self, mock_sleep, mock_logger): + ex = GithubException(status=500, data={"message": "Internal Server Error"}) + func = Mock(side_effect=ex) + with pytest.raises(GithubException): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + + # Delays: 2*2^0=2, 2*2^1=4, 2*2^2=8, 2*2^3=16 + sleep_delays = [call.args[0] for call in mock_sleep.call_args_list] + assert sleep_delays == [2, 4, 8, 16] + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_warning_logged_on_each_retry(self, mock_sleep, mock_logger): + ex = GithubException(status=500, data={"message": "Internal Server Error"}) + func = Mock(side_effect=[ex, ex, "ok"]) + + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + + assert result == "ok" + assert mock_logger.warning.call_count == 2 + # Verify attempt numbers in log messages + first_call_args = mock_logger.warning.call_args_list[0] + assert first_call_args[0][1] == f"{LOG_PREFIX} " # log_prefix + assert first_call_args[0][2] == 1 # attempt 1 + assert first_call_args[0][3] == 5 # total attempts (MAX_RETRIES + 1) + second_call_args = mock_logger.warning.call_args_list[1] + assert second_call_args[0][2] == 2 # attempt 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_retry_on_response_error(self, mock_sleep, mock_logger): + ex = ResponseError("too many 500 error responses") + func = Mock(side_effect=[ex, "ok"]) + result = await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + assert result == "ok" + assert func.call_count == 2 + + @patch("webhook_server.utils.github_retry.asyncio.sleep", new_callable=AsyncMock) + async def test_non_retryable_generic_exception(self, mock_sleep, mock_logger): + func = Mock(side_effect=ValueError("bad value")) + with pytest.raises(ValueError, match="bad value"): + await github_api_call(func, logger=mock_logger, log_prefix=LOG_PREFIX) + func.assert_called_once() + mock_sleep.assert_not_called() diff --git a/webhook_server/tests/test_runner_handler.py b/webhook_server/tests/test_runner_handler.py index 203ede24..f0c729e1 100644 --- a/webhook_server/tests/test_runner_handler.py +++ b/webhook_server/tests/test_runner_handler.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from github import GithubException from webhook_server.libs.handlers.runner_handler import CheckConfig, RunnerHandler from webhook_server.utils.constants import ( @@ -972,11 +973,17 @@ async def mock_checkout_worktree(**kwargs: Any) -> AsyncGenerator[tuple[bool, st @pytest.mark.asyncio async def test_is_branch_exists(self, runner_handler: RunnerHandler) -> None: - """Test is_branch_exists.""" + """Test is_branch_exists returns True when branch exists, False on 404.""" mock_branch = Mock() with patch("asyncio.to_thread", new=AsyncMock(return_value=mock_branch)): result = await runner_handler.is_branch_exists("main") - assert result == mock_branch + assert result is True + + with patch( + "asyncio.to_thread", new=AsyncMock(side_effect=GithubException(status=404, data="not found", headers={})) + ): + result = await runner_handler.is_branch_exists("non-existent-branch") + assert result is False @pytest.mark.asyncio async def test_cherry_pick_branch_not_exists(self, runner_handler: RunnerHandler, mock_pull_request: Mock) -> None: diff --git a/webhook_server/tests/test_test_oracle.py b/webhook_server/tests/test_test_oracle.py index a64d5c17..457143df 100644 --- a/webhook_server/tests/test_test_oracle.py +++ b/webhook_server/tests/test_test_oracle.py @@ -143,7 +143,7 @@ async def test_analyze_error_logs_only(self, mock_github_webhook: Mock, mock_pul with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread: await call_test_oracle(github_webhook=mock_github_webhook, pull_request=mock_pull_request) - # asyncio.to_thread is called once for pull_request.html_url, but not for posting a comment + # github_api_call is called once for pull_request.html_url, but not for posting a comment assert mock_to_thread.call_count == 1 mock_github_webhook.logger.error.assert_called() @@ -161,7 +161,7 @@ async def test_analyze_network_error_logs_only(self, mock_github_webhook: Mock, with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread: await call_test_oracle(github_webhook=mock_github_webhook, pull_request=mock_pull_request) - # asyncio.to_thread is called once for pull_request.html_url, but not for posting a comment + # github_api_call is called once for pull_request.html_url, but not for posting a comment assert mock_to_thread.call_count == 1 mock_github_webhook.logger.error.assert_called() diff --git a/webhook_server/utils/github_retry.py b/webhook_server/utils/github_retry.py new file mode 100644 index 00000000..f9fa7293 --- /dev/null +++ b/webhook_server/utils/github_retry.py @@ -0,0 +1,128 @@ +"""Retry wrapper for GitHub API calls made via asyncio.to_thread. + +PyGithub is synchronous and must be wrapped with asyncio.to_thread() for +non-blocking operation. When GitHub's API returns transient HTTP 500/502/503/504 +errors, urllib3's built-in retries exhaust quickly and raise ConnectionError +or ResponseError. This module provides application-level retry with +exponential backoff as a drop-in replacement for asyncio.to_thread(). + +Usage:: + + from webhook_server.utils.github_retry import github_api_call + + # Method calls with arguments + await github_api_call( + pull_request.create_issue_comment, + body="hello", + logger=self.logger, + log_prefix=self.log_prefix, + ) + + # Property access via lambda + await github_api_call( + lambda: pull_request.draft, + logger=self.logger, + log_prefix=self.log_prefix, + ) +""" + +import asyncio +import logging +from collections.abc import Callable +from typing import Any + +from github.GithubException import BadCredentialsException, GithubException, UnknownObjectException +from requests.exceptions import ConnectionError as RequestsConnectionError +from urllib3.exceptions import MaxRetryError, ResponseError + +_RETRYABLE_STATUS_CODES = frozenset({500, 502, 503, 504}) +_PERMANENT_STATUS_CODES = frozenset({401, 403, 404, 422}) +_RETRYABLE_SUBSTRINGS = ("500 error responses", "Max retries exceeded") + +_MAX_RETRIES = 4 +_BASE_DELAY = 2 + + +def _is_retryable(ex: Exception) -> bool: + """Determine whether an exception is retryable.""" + if isinstance(ex, (BadCredentialsException, UnknownObjectException)): + return False + + if isinstance(ex, GithubException): + if ex.status in _PERMANENT_STATUS_CODES: + return False + if ex.status in _RETRYABLE_STATUS_CODES: + return True + return False + + if isinstance(ex, (RequestsConnectionError, MaxRetryError, ResponseError)): + return True + + error_str = str(ex) + return any(substring in error_str for substring in _RETRYABLE_SUBSTRINGS) + + +async def github_api_call[T]( + func: Callable[..., T], + *args: Any, + logger: logging.Logger, + log_prefix: str, + **kwargs: Any, +) -> T: + """Execute a GitHub API call via asyncio.to_thread with retry on transient errors. + + Drop-in replacement for ``asyncio.to_thread(func, *args, **kwargs)`` that + retries on transient GitHub API errors (HTTP 500, 502, 503, 504) with exponential + backoff. + + Args: + func: The callable to execute in a thread. Can be a bound method + (e.g. ``pull_request.create_issue_comment``) or a lambda + (e.g. ``lambda: pull_request.draft``). + *args: Positional arguments forwarded to *func*. + logger: Logger instance used for retry warning messages. + log_prefix: Prefix string prepended to retry warning messages + (``self.log_prefix`` from the caller). + **kwargs: Keyword arguments forwarded to *func*. + + Returns: + The return value of *func*. + + Raises: + The original exception after all retries are exhausted, or immediately + for non-retryable errors (401, 403, 404, 422) and + ``asyncio.CancelledError``. + """ + # Note: retries may re-execute non-idempotent operations (e.g., create_issue_comment) + # if GitHub returned a transient error after partial completion. This is an accepted + # tradeoff — rare duplicate side effects are preferable to hard failures. + last_exception: Exception | None = None + + for attempt in range(_MAX_RETRIES + 1): + try: + return await asyncio.to_thread(func, *args, **kwargs) + except asyncio.CancelledError: + raise + except Exception as ex: + last_exception = ex + + if not _is_retryable(ex): + raise + + if attempt == _MAX_RETRIES: + break + + delay = _BASE_DELAY * (2**attempt) + logger.warning( + "%sGitHub API call failed (attempt %d/%d), retrying in %ds: %s: %s", + f"{log_prefix} " if log_prefix else "", + attempt + 1, + _MAX_RETRIES + 1, + delay, + type(ex).__name__, + ex, + ) + await asyncio.sleep(delay) + + assert last_exception is not None # noqa: S101 + raise last_exception