diff --git a/webhook_server/libs/handlers/pull_request_handler.py b/webhook_server/libs/handlers/pull_request_handler.py index 1776e906..0e3da70b 100644 --- a/webhook_server/libs/handlers/pull_request_handler.py +++ b/webhook_server/libs/handlers/pull_request_handler.py @@ -732,21 +732,132 @@ async def remove_labels_when_pull_request_sync(self, pull_request: PullRequest) if isinstance(result, Exception): self.logger.error(f"{self.log_prefix} Async task failed: {result}") + async def _compare_branches(self, base_ref: str, head_ref_full: str) -> dict[str, Any] | None: + """Call GitHub Compare API to get branch comparison data for rebase detection. + + This API is used ONLY for detecting if a PR is behind/diverged from base branch. + It does NOT provide conflict information - use pull_request.mergeable for conflicts. + + Args: + base_ref: Base branch reference (e.g., "main"). + head_ref_full: Full head reference including owner (e.g., "user:branch"). + + Returns: + Compare API response data or None if API call fails. + + Compare API Reference: + GET /repos/{owner}/{repo}/compare/{base}...{head} + Response fields used: + - behind_by: int - commits behind base branch + - status: str - "ahead", "behind", "diverged", "identical" + + NOTE: This API does NOT return conflict information (mergeable/mergeable_state). + """ + try: + _headers, data = await asyncio.to_thread( + self.repository._requester.requestJsonAndCheck, + "GET", + f"{self.repository.url}/compare/{base_ref}...{head_ref_full}", + ) + return data + except GithubException: + self.logger.exception(f"{self.log_prefix} Failed to call Compare API for {base_ref}...{head_ref_full}") + return None + except Exception: + self.logger.exception(f"{self.log_prefix} Unexpected error calling Compare API") + return None + async def label_pull_request_by_merge_state(self, pull_request: PullRequest) -> None: - merge_state = await asyncio.to_thread(lambda: pull_request.mergeable_state) - self.logger.debug(f"{self.log_prefix} Mergeable state is {merge_state}") - if merge_state == "unknown": - return + """Label pull request based on merge state. - if merge_state == "behind": - await self.labels_handler._add_label(pull_request=pull_request, label=NEEDS_REBASE_LABEL_STR) - else: - await self.labels_handler._remove_label(pull_request=pull_request, label=NEEDS_REBASE_LABEL_STR) + Simple flow: + 1. Check pull_request.mergeable for conflicts + 2. If has conflicts → add has-conflicts, exit + 3. Else → remove has-conflicts, check Compare API for rebase status - if merge_state == "dirty": - await self.labels_handler._add_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR) - else: - await self.labels_handler._remove_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR) + Uses both GitHub APIs for accurate labeling: + - has-conflicts: pull_request.mergeable == False (true merge conflict detection) + - needs-rebase: Compare API behind_by > 0 or status == "diverged" + + Both labels can coexist - they both reflect the actual PR state. + + Args: + pull_request: The GitHub pull request object to label. + """ + if self.ctx: + self.ctx.start_step("label_merge_state") + + try: + # Get current labels (single API call for optimization) + current_labels = await self.labels_handler.pull_request_labels_names(pull_request=pull_request) + has_conflicts_label_exists = HAS_CONFLICTS_LABEL_STR in current_labels + needs_rebase_label_exists = NEEDS_REBASE_LABEL_STR in current_labels + + # Step 1: Check for conflicts first + mergeable = await asyncio.to_thread(lambda: pull_request.mergeable) + has_conflicts = mergeable is False + + if has_conflicts: + # Has conflicts - add has-conflicts label and exit + self.logger.debug(f"{self.log_prefix} PR has conflicts. {mergeable=}") + + if not has_conflicts_label_exists: + self.logger.debug(f"{self.log_prefix} Adding {HAS_CONFLICTS_LABEL_STR} label") + await self.labels_handler._add_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR) + + if self.ctx: + self.ctx.complete_step("label_merge_state", has_conflicts=True, needs_rebase=False) + return # Exit early - conflicts take precedence + + # Step 2: No conflicts - remove has-conflicts label if present + if has_conflicts_label_exists: + self.logger.debug(f"{self.log_prefix} Removing {HAS_CONFLICTS_LABEL_STR} label") + await self.labels_handler._remove_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR) + + # Step 3: Check if needs rebase via Compare API + base_ref, head_user_login, head_ref = await asyncio.gather( + asyncio.to_thread(lambda: pull_request.base.ref), + asyncio.to_thread(lambda: pull_request.head.user.login), + asyncio.to_thread(lambda: pull_request.head.ref), + ) + head_ref_full = f"{head_user_login}:{head_ref}" + + compare_data = await self._compare_branches(base_ref=base_ref, head_ref_full=head_ref_full) + if compare_data is None: + self.logger.warning(f"{self.log_prefix} Compare API failed, skipping rebase label update") + if self.ctx: + self.ctx.complete_step("label_merge_state", compare_api_failed=True) + return + + behind_by = compare_data.get("behind_by", 0) + status = compare_data.get("status", "") + + needs_rebase = behind_by > 0 or status == "diverged" + + self.logger.debug( + f"{self.log_prefix} Compare API - behind_by: {behind_by}, " + f"status: {status}, needs_rebase: {needs_rebase}" + ) + + # Step 4: Update needs-rebase label + if needs_rebase and not needs_rebase_label_exists: + self.logger.debug(f"{self.log_prefix} Adding {NEEDS_REBASE_LABEL_STR} label") + await self.labels_handler._add_label(pull_request=pull_request, label=NEEDS_REBASE_LABEL_STR) + elif not needs_rebase and needs_rebase_label_exists: + self.logger.debug(f"{self.log_prefix} Removing {NEEDS_REBASE_LABEL_STR} label") + await self.labels_handler._remove_label(pull_request=pull_request, label=NEEDS_REBASE_LABEL_STR) + + if self.ctx: + self.ctx.complete_step("label_merge_state", has_conflicts=False, needs_rebase=needs_rebase) + + except asyncio.CancelledError: + self.logger.debug(f"{self.log_prefix} Label merge state check cancelled") + raise + except Exception as ex: + self.logger.exception(f"{self.log_prefix} Failed to label merge state") + if self.ctx: + self.ctx.fail_step("label_merge_state", ex, traceback.format_exc()) + raise async def _process_verified_for_update_or_new_pull_request(self, pull_request: PullRequest) -> None: if not self.github_webhook.verified_job: diff --git a/webhook_server/tests/test_pull_request_handler.py b/webhook_server/tests/test_pull_request_handler.py index 7222735c..78a15929 100644 --- a/webhook_server/tests/test_pull_request_handler.py +++ b/webhook_server/tests/test_pull_request_handler.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -524,9 +525,31 @@ async def test_remove_labels_when_pull_request_sync( async def test_label_pull_request_by_merge_state_mergeable( self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock ) -> None: + """Test labeling pull request when mergeable and up-to-date.""" mock_pull_request.mergeable = True - mock_pull_request.mergeable_state = "clean" - with patch.object(pull_request_handler.labels_handler, "_remove_label", new=AsyncMock()) as mock_remove_label: + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + + # Mock existing labels - PR currently has both labels that need to be removed + mock_label1 = Mock() + mock_label1.name = HAS_CONFLICTS_LABEL_STR + mock_label2 = Mock() + mock_label2.name = NEEDS_REBASE_LABEL_STR + mock_pull_request.labels = [mock_label1, mock_label2] + + # Mock Compare API response - up-to-date + mock_compare_data = {"behind_by": 0, "status": "ahead"} + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + with ( + patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[HAS_CONFLICTS_LABEL_STR, NEEDS_REBASE_LABEL_STR]), + ), + patch.object(pull_request_handler.labels_handler, "_remove_label", new=AsyncMock()) as mock_remove_label, + ): await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request) assert mock_remove_label.await_count == 2 @@ -536,9 +559,25 @@ async def test_label_pull_request_by_merge_state_needs_rebase( ) -> None: """Test labeling pull request by merge state when needs rebase.""" mock_pull_request.mergeable = True - mock_pull_request.mergeable_state = "behind" + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" - with patch.object(pull_request_handler.labels_handler, "_add_label") as mock_add_label: + # Mock existing labels - PR has no labels currently + mock_pull_request.labels = [] + + # Mock Compare API response - behind + mock_compare_data = {"behind_by": 5, "status": "behind"} + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + with ( + patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ), + patch.object(pull_request_handler.labels_handler, "_add_label") as mock_add_label, + ): await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request) mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=NEEDS_REBASE_LABEL_STR) @@ -546,12 +585,33 @@ async def test_label_pull_request_by_merge_state_needs_rebase( async def test_label_pull_request_by_merge_state_has_conflicts( self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock ) -> None: - """Test labeling pull request by merge state when has conflicts.""" - mock_pull_request.mergeable = False - mock_pull_request.mergeable_state = "dirty" + """Test labeling pull request by merge state when has conflicts. - with patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label: + Uses pull_request.mergeable == False to detect conflicts. + When mergeable is False, ONLY has-conflicts label is set (conflicts take precedence over needs-rebase). + """ + mock_pull_request.mergeable = False # Conflict detected via mergeable + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + + # Mock existing labels - PR has no labels currently + mock_pull_request.labels = [] + + # Mock Compare API response - clean (no rebase needed) + mock_compare_data = {"behind_by": 0, "status": "ahead"} + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + with ( + patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ), + patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label, + ): await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request) + # When mergeable is False, only has-conflicts label is set (conflicts take precedence) mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=HAS_CONFLICTS_LABEL_STR) @pytest.mark.asyncio @@ -1708,16 +1768,134 @@ async def test_set_pull_request_automerge_exception( async def test_label_pull_request_by_merge_state_unknown( self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock ) -> None: - """Test label_pull_request_by_merge_state when unknown.""" - mock_pull_request.mergeable_state = "unknown" + """Test label_pull_request_by_merge_state when mergeable=None. - with patch( - "asyncio.to_thread", side_effect=lambda f, *args, **kwargs: f(*args, **kwargs) if callable(f) else None + When mergeable=None (not yet computed), has_conflicts is False. + If Compare API shows behind_by > 0, needs-rebase label should be added. + """ + mock_pull_request.mergeable = None # Not yet computed by GitHub + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + + # Mock existing labels - PR has no labels currently + mock_pull_request.labels = [] + + # Mock Compare API response - behind by 5 commits + mock_compare_data = {"behind_by": 5, "status": "behind"} + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + with ( + patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ), + patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label, ): await pull_request_handler.label_pull_request_by_merge_state(mock_pull_request) + # Should add needs-rebase label since behind_by > 0 and no conflicts (mergeable=None means no conflict) + mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=NEEDS_REBASE_LABEL_STR) - # Should return early - pull_request_handler.labels_handler._add_label.assert_not_called() + @pytest.mark.asyncio + async def test_label_pull_request_by_merge_state_diverged( + self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock + ) -> None: + """Test labeling pull request when diverged from base. + + Uses Compare API status='diverged' to detect needs-rebase. + When status='diverged', only needs-rebase label is set (no conflicts via mergeable). + """ + mock_pull_request.mergeable = True # No conflicts + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + + # Mock existing labels - PR has no labels currently + mock_pull_request.labels = [] + + # Mock Compare API response - diverged (needs rebase) + mock_compare_data = {"behind_by": 3, "status": "diverged"} + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + with ( + patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ), + patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label, + ): + await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request) + # When diverged and no conflicts, only needs-rebase label is set + mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=NEEDS_REBASE_LABEL_STR) + + @pytest.mark.asyncio + async def test_label_pull_request_by_merge_state_diverged_zero_behind( + self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock + ) -> None: + """Test diverged status with zero behind_by (edge case). + + When status is 'diverged' but behind_by is 0, needs_rebase should still be True + because diverged means the branch has both commits ahead AND commits that differ + from the base branch. + """ + mock_pull_request.mergeable = True # No conflicts + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + mock_pull_request.labels = [] # No existing labels + + # Edge case: diverged but behind_by=0 + mock_compare_data = {"behind_by": 0, "status": "diverged"} + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + with ( + patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ), + patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label, + ): + await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request) + + # Should add needs-rebase because status="diverged" (even with behind_by=0) + mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=NEEDS_REBASE_LABEL_STR) + + @pytest.mark.asyncio + async def test_label_pull_request_by_merge_state_behind_and_conflicts( + self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock + ) -> None: + """Test labeling pull request when behind and has conflicts. + + Uses pull_request.mergeable == False to detect conflicts. + Uses Compare API status='diverged' to detect needs-rebase. + When both exist, ONLY has-conflicts label is set (conflicts take precedence over needs-rebase). + """ + mock_pull_request.mergeable = False # Conflict detected via mergeable + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + + # Mock existing labels - PR has no labels currently + mock_pull_request.labels = [] + + # Mock Compare API response - diverged (needs rebase) + mergeable=False (conflicts) + mock_compare_data = {"behind_by": 2, "status": "diverged"} + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + with ( + patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ), + patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label, + ): + await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request) + # When mergeable is False (conflicts), only has-conflicts label is set (conflicts take precedence) + mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=HAS_CONFLICTS_LABEL_STR) @pytest.mark.asyncio async def test_delete_registry_tag_via_regctl_failure( @@ -1752,3 +1930,73 @@ async def test_delete_registry_tag_via_regctl_failure( pull_request_handler.logger.error.assert_called_with( "[TEST] Failed to delete tag: tag. OUT:Delete failed. ERR:Error" ) + + @pytest.mark.asyncio + async def test_label_pull_request_by_merge_state_compare_api_failure( + self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock + ) -> None: + """Test handling of Compare API failure - should log warning and return without updating labels.""" + mock_pull_request.mergeable = True # No conflicts (not used anymore) + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + + # Mock existing labels + mock_pull_request.labels = [] + + # Mock Compare API to raise GithubException + pull_request_handler.repository._requester.requestJsonAndCheck = Mock( + side_effect=GithubException(500, {"message": "API error"}, None) + ) + + # Reset mocks + pull_request_handler.labels_handler._add_label.reset_mock() + pull_request_handler.labels_handler._remove_label.reset_mock() + + with patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ): + await pull_request_handler.label_pull_request_by_merge_state(mock_pull_request) + + # With new simplified logic: if Compare API fails, no label updates at all + pull_request_handler.labels_handler._remove_label.assert_not_called() + pull_request_handler.labels_handler._add_label.assert_not_called() + + @pytest.mark.asyncio + async def test_label_pull_request_by_merge_state_incomplete_compare_data( + self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock + ) -> None: + """Test handling of incomplete Compare API response. + + With combined logic, pull_request.mergeable is used for conflicts. + If Compare API has missing behind_by but mergeable is False, conflict label is still added. + """ + mock_pull_request.mergeable = False # Conflict detected via mergeable + mock_pull_request.base.ref = "main" + mock_pull_request.head.user.login = "test-user" + mock_pull_request.head.ref = "feature-branch" + + # Mock existing labels - PR has no labels currently + mock_pull_request.labels = [] + + # Mock Compare API with missing behind_by key - status is 'behind' (not diverged) + mock_compare_data: dict[str, Any] = {"status": "behind"} # Missing behind_by + pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data)) + + # Reset mocks + pull_request_handler.labels_handler._add_label.reset_mock() + pull_request_handler.labels_handler._remove_label.reset_mock() + + with patch.object( + pull_request_handler.labels_handler, + "pull_request_labels_names", + new=AsyncMock(return_value=[]), + ): + await pull_request_handler.label_pull_request_by_merge_state(mock_pull_request) + + # mergeable is False, so conflict label should be added + pull_request_handler.labels_handler._add_label.assert_called_once_with( + pull_request=mock_pull_request, label=HAS_CONFLICTS_LABEL_STR + )