diff --git a/examples/.github-webhook-server.yaml b/examples/.github-webhook-server.yaml index 0a5f7238..6bccbab1 100644 --- a/examples/.github-webhook-server.yaml +++ b/examples/.github-webhook-server.yaml @@ -121,6 +121,8 @@ minimum-lgtm: 2 # Issue creation for new pull requests create-issue-for-new-pr: true # Create tracking issues for new PRs +cherry-pick-assign-to-pr-author: true # Assign cherry-pick PRs to the original PR author (default: true) + # Custom PR size labels for this repository (overrides global configuration) # Define custom categories based on total lines changed (additions + deletions) # threshold: positive integer or 'inf' for unbounded largest category diff --git a/examples/config.yaml b/examples/config.yaml index f71ab7c9..d22e08a4 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -33,6 +33,8 @@ auto-verify-cherry-picked-prs: true # Default: true - automatically verify cher create-issue-for-new-pr: true # Global default: create tracking issues for new PRs +cherry-pick-assign-to-pr-author: true # Default: true - assign cherry-pick PRs to the original PR author + # Commands allowed on draft PRs (optional) # If not set: commands are blocked on draft PRs (default behavior) # If empty list []: all commands allowed on draft PRs diff --git a/webhook_server/config/schema.yaml b/webhook_server/config/schema.yaml index 0f640de2..9f70bcc3 100644 --- a/webhook_server/config/schema.yaml +++ b/webhook_server/config/schema.yaml @@ -83,6 +83,10 @@ properties: type: boolean description: Create a tracking issue for new pull requests (global default) default: true + cherry-pick-assign-to-pr-author: + type: boolean + description: Assign cherry-pick PRs to the original PR author (default true) + default: true allow-commands-on-draft-prs: type: array items: @@ -388,6 +392,10 @@ properties: type: boolean description: Create a tracking issue for new pull requests default: true + cherry-pick-assign-to-pr-author: + type: boolean + description: Assign cherry-pick PRs to the original PR author (overrides global setting) + default: true allow-commands-on-draft-prs: type: array items: diff --git a/webhook_server/libs/github_api.py b/webhook_server/libs/github_api.py index 58daf24d..23704bd3 100644 --- a/webhook_server/libs/github_api.py +++ b/webhook_server/libs/github_api.py @@ -735,6 +735,17 @@ def _repo_data_from_config(self, repository_config: dict[str, Any]) -> None: value="create-issue-for-new-pr", return_on_none=global_create_issue_for_new_pr, extra_dict=repository_config ) + # Load global cherry_pick_assign_to_pr_author setting as fallback + global_cherry_pick_assign: bool = self.config.get_value( + value="cherry-pick-assign-to-pr-author", return_on_none=True + ) + # Repository-specific setting overrides global setting + self.cherry_pick_assign_to_pr_author: bool = self.config.get_value( + value="cherry-pick-assign-to-pr-author", + return_on_none=global_cherry_pick_assign, + extra_dict=repository_config, + ) + # Read required_conversation_resolution from branch-protection config _bp_key = "required_conversation_resolution" _bp_raw_default = DEFAULT_BRANCH_PROTECTION[_bp_key] diff --git a/webhook_server/libs/handlers/issue_comment_handler.py b/webhook_server/libs/handlers/issue_comment_handler.py index 940fd417..f473bad7 100644 --- a/webhook_server/libs/handlers/issue_comment_handler.py +++ b/webhook_server/libs/handlers/issue_comment_handler.py @@ -456,7 +456,7 @@ async def process_cherry_pick_command( await self.runner_handler.cherry_pick( pull_request=pull_request, target_branch=_exits_target_branch, - reviewed_user=reviewed_user, + assign_to_pr_owner=self.github_webhook.cherry_pick_assign_to_pr_author, ) for _cp_label in cp_labels: diff --git a/webhook_server/libs/handlers/pull_request_handler.py b/webhook_server/libs/handlers/pull_request_handler.py index 512a504a..c181d8e7 100644 --- a/webhook_server/libs/handlers/pull_request_handler.py +++ b/webhook_server/libs/handlers/pull_request_handler.py @@ -152,11 +152,18 @@ async def process_pull_request_webhook_data(self, pull_request: PullRequest) -> self.logger.info(f"{self.log_prefix} PR is merged") labels = await asyncio.to_thread(lambda: list(pull_request.labels)) - for _label in labels: - _label_name = _label.name - if _label_name.startswith(CHERRY_PICK_LABEL_PREFIX): + if cherry_pick_labels := [ + _label for _label in labels if _label.name.startswith(CHERRY_PICK_LABEL_PREFIX) + ]: + for _label in cherry_pick_labels: + target_branch = _label.name.removeprefix(CHERRY_PICK_LABEL_PREFIX) + if not target_branch: + self.logger.warning(f"{self.log_prefix} Skipping invalid cherry-pick label: {_label.name}") + continue await self.runner_handler.cherry_pick( - pull_request=pull_request, target_branch=_label_name.replace(CHERRY_PICK_LABEL_PREFIX, "") + pull_request=pull_request, + target_branch=target_branch, + assign_to_pr_owner=self.github_webhook.cherry_pick_assign_to_pr_author, ) await self.runner_handler.run_build_container( diff --git a/webhook_server/libs/handlers/runner_handler.py b/webhook_server/libs/handlers/runner_handler.py index b0adb1fe..ddb04bff 100644 --- a/webhook_server/libs/handlers/runner_handler.py +++ b/webhook_server/libs/handlers/runner_handler.py @@ -536,9 +536,18 @@ async def run_custom_check( async def is_branch_exists(self, branch: str) -> Branch: return await asyncio.to_thread(self.repository.get_branch, branch) - async def cherry_pick(self, pull_request: PullRequest, target_branch: str, reviewed_user: str = "") -> None: - requested_by = reviewed_user or "by target-branch label" - self.logger.info(f"{self.log_prefix} Cherry-pick requested by user: {requested_by}") + async def cherry_pick( + self, + pull_request: PullRequest, + target_branch: str, + assign_to_pr_owner: bool = True, + ) -> None: + pr_author = await asyncio.to_thread(lambda: pull_request.user.login) + source_branch = await asyncio.to_thread(lambda: pull_request.base.ref) + + self.logger.info( + f"{self.log_prefix} Cherry-pick from {source_branch} to {target_branch}, PR owner: {pr_author}" + ) new_branch_name = f"{CHERRY_PICKED_LABEL}-{pull_request.head.ref}-{shortuuid.uuid()[:5]}" if not await self.is_branch_exists(branch=target_branch): @@ -556,6 +565,7 @@ async def cherry_pick(self, pull_request: PullRequest, target_branch: str, revie async with self._checkout_worktree(pull_request=pull_request) as (success, worktree_path, out, err): git_cmd = f"git --work-tree={worktree_path} --git-dir={worktree_path}/.git" hub_cmd = f"GITHUB_TOKEN={github_token} hub --work-tree={worktree_path} --git-dir={worktree_path}/.git" + assignee_flag = f" -a {shlex.quote(pr_author)}" if assign_to_pr_owner else "" commands: list[str] = [ f"{git_cmd} checkout {target_branch}", f"{git_cmd} pull origin {target_branch}", @@ -563,10 +573,10 @@ async def cherry_pick(self, pull_request: PullRequest, target_branch: str, revie f"{git_cmd} cherry-pick {commit_hash}", f"{git_cmd} push origin {new_branch_name}", f'bash -c "{hub_cmd} pull-request -b {target_branch} ' - f"-h {new_branch_name} -l {CHERRY_PICKED_LABEL} " + f"-h {new_branch_name} -l {CHERRY_PICKED_LABEL} {assignee_flag} " f"-m '{CHERRY_PICKED_LABEL}: [{target_branch}] " - f"{commit_msg_striped}' -m 'cherry-pick {pull_request_url} " - f"into {target_branch}' -m 'requested-by {requested_by}'\"", + f"{commit_msg_striped}' -m 'Cherry-pick from `{source_branch}` branch, " + f"original PR: {pull_request_url}, PR owner: {pr_author}'\"", ] output: CheckRunOutput = { diff --git a/webhook_server/tests/test_config_schema.py b/webhook_server/tests/test_config_schema.py index 9e0e95b8..046e1204 100644 --- a/webhook_server/tests/test_config_schema.py +++ b/webhook_server/tests/test_config_schema.py @@ -41,6 +41,7 @@ def valid_full_config(self) -> dict[str, Any]: "docker": {"username": "dockeruser", "password": "dockerpass"}, # pragma: allowlist secret "default-status-checks": ["WIP", "build"], "auto-verified-and-merged-users": ["bot[bot]"], + "cherry-pick-assign-to-pr-author": True, "branch-protection": { "strict": True, "require_code_owner_reviews": True, diff --git a/webhook_server/tests/test_issue_comment_handler.py b/webhook_server/tests/test_issue_comment_handler.py index c31fcd74..0ff45c5b 100644 --- a/webhook_server/tests/test_issue_comment_handler.py +++ b/webhook_server/tests/test_issue_comment_handler.py @@ -45,6 +45,7 @@ def mock_github_webhook(self) -> Mock: mock_webhook.current_pull_request_supported_retest = [TOX_STR, "pre-commit"] mock_webhook.ctx = None mock_webhook.custom_check_runs = [] + mock_webhook.cherry_pick_assign_to_pr_author = True # Mock config for draft PR command filtering mock_webhook.config = Mock() mock_webhook.config.get_value = Mock(return_value=None) @@ -854,7 +855,7 @@ async def test_process_cherry_pick_command_merged_pr(self, issue_comment_handler mock_cherry_pick.assert_called_once_with( pull_request=mock_pull_request, target_branch="branch1", - reviewed_user="test-user", + assign_to_pr_owner=True, ) mock_add_label.assert_called_once_with( pull_request=mock_pull_request, @@ -900,17 +901,17 @@ async def test_process_cherry_pick_command_merged_pr_multiple_branches( mock_cherry_pick.assert_any_call( pull_request=mock_pull_request, target_branch="branch1", - reviewed_user="test-user", + assign_to_pr_owner=True, ) mock_cherry_pick.assert_any_call( pull_request=mock_pull_request, target_branch="branch2", - reviewed_user="test-user", + assign_to_pr_owner=True, ) mock_cherry_pick.assert_any_call( pull_request=mock_pull_request, target_branch="branch3", - reviewed_user="test-user", + assign_to_pr_owner=True, ) # Verify labels were added exactly once for each branch (not duplicated) @@ -919,6 +920,36 @@ async def test_process_cherry_pick_command_merged_pr_multiple_branches( mock_add_label.assert_any_call(pull_request=mock_pull_request, label="cherry-pick-branch2") mock_add_label.assert_any_call(pull_request=mock_pull_request, label="cherry-pick-branch3") + @pytest.mark.asyncio + async def test_process_cherry_pick_command_merged_pr_assign_disabled( + self, issue_comment_handler: IssueCommentHandler + ) -> None: + """Test cherry-pick on merged PR passes assign_to_pr_owner=False when config disabled.""" + issue_comment_handler.github_webhook.cherry_pick_assign_to_pr_author = False + mock_pull_request = Mock() + with patch.object(mock_pull_request, "is_merged", new=Mock(return_value=True)): + with patch.object(issue_comment_handler.repository, "get_branch"): + with patch.object( + issue_comment_handler.runner_handler, + "cherry_pick", + new_callable=AsyncMock, + ) as mock_cherry_pick: + with patch.object( + issue_comment_handler.labels_handler, + "_add_label", + new_callable=AsyncMock, + ): + await issue_comment_handler.process_cherry_pick_command( + pull_request=mock_pull_request, + command_args="branch1", + reviewed_user="test-user", + ) + mock_cherry_pick.assert_called_once_with( + pull_request=mock_pull_request, + target_branch="branch1", + assign_to_pr_owner=False, + ) + @pytest.mark.asyncio async def test_process_retest_command_no_target_tests(self, issue_comment_handler: IssueCommentHandler) -> None: """Test processing retest command with no target tests.""" diff --git a/webhook_server/tests/test_pull_request_handler.py b/webhook_server/tests/test_pull_request_handler.py index 8be5eef1..50678600 100644 --- a/webhook_server/tests/test_pull_request_handler.py +++ b/webhook_server/tests/test_pull_request_handler.py @@ -85,6 +85,7 @@ def mock_github_webhook(self) -> Mock: mock_webhook.pypi = False mock_webhook.token = TEST_GITHUB_TOKEN mock_webhook.auto_verify_cherry_picked_prs = True + mock_webhook.cherry_pick_assign_to_pr_author = True mock_webhook.last_commit = Mock() mock_webhook.ctx = None mock_webhook.enabled_labels = None # Default: all labels enabled @@ -294,7 +295,9 @@ async def test_process_pull_request_webhook_data_closed_action_merged( ) mock_delete_tag.assert_called_once_with(pull_request=mock_pull_request) mock_cherry_pick.assert_called_once_with( - pull_request=mock_pull_request, target_branch="branch1" + pull_request=mock_pull_request, + target_branch="branch1", + assign_to_pr_owner=True, ) mock_build.assert_called_once_with( push=True, @@ -304,6 +307,75 @@ async def test_process_pull_request_webhook_data_closed_action_merged( ) mock_label_all.assert_called_once() + @pytest.mark.asyncio + async def test_process_pull_request_cherry_pick_label_multiple_branches( + self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock + ) -> None: + """Test cherry-pick is triggered for each cherry-pick label on merge.""" + pull_request_handler.hook_data["action"] = "closed" + pull_request_handler.hook_data["pull_request"]["merged"] = True + + mock_label1 = Mock() + mock_label1.name = f"{CHERRY_PICK_LABEL_PREFIX}branch1" + mock_label2 = Mock() + mock_label2.name = f"{CHERRY_PICK_LABEL_PREFIX}branch2" + mock_pull_request.labels = [mock_label1, mock_label2] + + with patch.object(pull_request_handler, "close_issue_for_merged_or_closed_pr"): + with patch.object(pull_request_handler, "delete_remote_tag_for_merged_or_closed_pr"): + with patch.object( + pull_request_handler.runner_handler, "cherry_pick", new_callable=AsyncMock + ) as mock_cherry_pick: + with patch.object( + pull_request_handler.runner_handler, "run_build_container", new_callable=AsyncMock + ): + with patch.object( + pull_request_handler, "label_all_opened_pull_requests_merge_state_after_merged" + ): + await pull_request_handler.process_pull_request_webhook_data(mock_pull_request) + assert mock_cherry_pick.call_count == 2 + mock_cherry_pick.assert_any_call( + pull_request=mock_pull_request, + target_branch="branch1", + assign_to_pr_owner=True, + ) + mock_cherry_pick.assert_any_call( + pull_request=mock_pull_request, + target_branch="branch2", + assign_to_pr_owner=True, + ) + + @pytest.mark.asyncio + async def test_process_pull_request_cherry_pick_label_assign_disabled( + self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock + ) -> None: + """Test cherry-pick passes assign_to_pr_owner=False when config disabled.""" + pull_request_handler.hook_data["action"] = "closed" + pull_request_handler.hook_data["pull_request"]["merged"] = True + pull_request_handler.github_webhook.cherry_pick_assign_to_pr_author = False + + mock_label = Mock() + mock_label.name = f"{CHERRY_PICK_LABEL_PREFIX}target-branch" + mock_pull_request.labels = [mock_label] + + with patch.object(pull_request_handler, "close_issue_for_merged_or_closed_pr"): + with patch.object(pull_request_handler, "delete_remote_tag_for_merged_or_closed_pr"): + with patch.object( + pull_request_handler.runner_handler, "cherry_pick", new_callable=AsyncMock + ) as mock_cherry_pick: + with patch.object( + pull_request_handler.runner_handler, "run_build_container", new_callable=AsyncMock + ): + with patch.object( + pull_request_handler, "label_all_opened_pull_requests_merge_state_after_merged" + ): + await pull_request_handler.process_pull_request_webhook_data(mock_pull_request) + mock_cherry_pick.assert_called_once_with( + pull_request=mock_pull_request, + target_branch="target-branch", + assign_to_pr_owner=False, + ) + @pytest.mark.asyncio async def test_process_pull_request_webhook_data_labeled_action( self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock diff --git a/webhook_server/tests/test_runner_handler.py b/webhook_server/tests/test_runner_handler.py index 0ec92acc..0c3e66f3 100644 --- a/webhook_server/tests/test_runner_handler.py +++ b/webhook_server/tests/test_runner_handler.py @@ -1,4 +1,6 @@ -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator +from contextlib import asynccontextmanager +from dataclasses import dataclass from unittest.mock import AsyncMock, Mock, patch import pytest @@ -13,6 +15,15 @@ ) +@dataclass +class CherryPickMocks: + set_progress: Mock + set_success: Mock + run_cmd: Mock + comment: Mock + to_thread: Mock + + class TestRunnerHandler: """Test suite for RunnerHandler class.""" @@ -68,6 +79,8 @@ def mock_pull_request(self) -> Mock: mock_pr.head.ref = "feature-branch" mock_pr.merge_commit_sha = "abc123" mock_pr.html_url = "https://github.com/test/repo/pull/123" + mock_pr.user = Mock() + mock_pr.user.login = "test-pr-author" mock_pr.create_issue_comment = Mock() return mock_pr @@ -918,6 +931,72 @@ async def test_cherry_pick_manual_needed(self, runner_handler, mock_pull_request mock_set_failure.assert_called_once() mock_comment.assert_called_once() + @staticmethod + @asynccontextmanager + async def cherry_pick_setup( + runner_handler: RunnerHandler, + mock_pull_request: Mock, + ) -> AsyncGenerator[CherryPickMocks]: + """Common setup for cherry-pick tests.""" + runner_handler.github_webhook.pypi = {"token": "dummy"} + with patch.object(runner_handler, "is_branch_exists", new=AsyncMock(return_value=Mock())): + with patch.object(runner_handler.check_run_handler, "set_check_in_progress") as mock_set_progress: + with patch.object(runner_handler.check_run_handler, "set_check_success") as mock_set_success: + with patch.object(runner_handler, "_checkout_worktree") as mock_checkout: + mock_checkout.return_value = AsyncMock() + mock_checkout.return_value.__aenter__ = AsyncMock( + return_value=(True, "/tmp/worktree-path", "", "") + ) + mock_checkout.return_value.__aexit__ = AsyncMock(return_value=None) + with patch( + "webhook_server.libs.handlers.runner_handler.run_command", + new=AsyncMock(return_value=(True, "success", "")), + ) as mock_run_cmd: + with patch.object(mock_pull_request, "create_issue_comment", new=Mock()) as mock_comment: + with patch( + "asyncio.to_thread", + new=AsyncMock(side_effect=lambda fn, *a, **kw: fn(*a, **kw) if a or kw else fn()), + ) as mock_to_thread: + yield CherryPickMocks( + set_progress=mock_set_progress, + set_success=mock_set_success, + run_cmd=mock_run_cmd, + comment=mock_comment, + to_thread=mock_to_thread, + ) + + @pytest.mark.asyncio + async def test_cherry_pick_assigns_pr_author(self, runner_handler: RunnerHandler, mock_pull_request: Mock) -> None: + """Test cherry_pick assigns to PR author, not the cherry-pick requester.""" + async with self.cherry_pick_setup(runner_handler, mock_pull_request) as mocks: + await runner_handler.cherry_pick(mock_pull_request, "main") + mocks.set_progress.assert_called_once() + mocks.set_success.assert_called_once() + mocks.comment.assert_called_once() + assert mocks.to_thread.call_count == 3 + last_cmd = mocks.run_cmd.call_args_list[-1] + hub_command = last_cmd.kwargs.get("command", last_cmd.args[0] if last_cmd.args else "") + assert "-a 'test-pr-author'" in hub_command or "-a test-pr-author" in hub_command + + @pytest.mark.asyncio + async def test_cherry_pick_requested_by_uses_pr_owner( + self, runner_handler: RunnerHandler, mock_pull_request: Mock + ) -> None: + """Test cherry_pick PR description includes source branch and PR owner.""" + async with self.cherry_pick_setup(runner_handler, mock_pull_request) as mocks: + await runner_handler.cherry_pick(mock_pull_request, "main") + mocks.set_progress.assert_called_once() + mocks.set_success.assert_called_once() + mocks.comment.assert_called_once() + last_cmd = mocks.run_cmd.call_args_list[-1] + hub_command = last_cmd.kwargs.get("command", last_cmd.args[0] if last_cmd.args else "") + expected_msg = ( + f"Cherry-pick from `main` branch, original PR: {mock_pull_request.html_url}, PR owner: test-pr-author" + ) + assert expected_msg in hub_command + assert "-a 'test-pr-author'" in hub_command or "-a test-pr-author" in hub_command + assert mocks.to_thread.call_count == 3 + @pytest.mark.asyncio async def test_checkout_worktree_branch_already_checked_out( self, runner_handler: RunnerHandler, mock_pull_request: Mock diff --git a/webhook_server/tests/test_schema_validator.py b/webhook_server/tests/test_schema_validator.py index c00b578a..1784e9bf 100644 --- a/webhook_server/tests/test_schema_validator.py +++ b/webhook_server/tests/test_schema_validator.py @@ -73,6 +73,7 @@ def _validate_root_fields(self, config: dict[str, Any]) -> None: "disable-ssl-warnings", "mask-sensitive-data", "auto-verify-cherry-picked-prs", + "cherry-pick-assign-to-pr-author", ] for field in boolean_fields: if field in config and not isinstance(config[field], bool): @@ -171,6 +172,7 @@ def _validate_single_repository(self, repo_name: str, repo_config: Any) -> None: "pre-commit", "mask-sensitive-data", "auto-verify-cherry-picked-prs", + "cherry-pick-assign-to-pr-author", ] for field in boolean_fields: if field in repo_config and not isinstance(repo_config[field], bool):