From 6bf883487d62a7cde9338c5f82d233f3fbdb4cca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Ma=C5=88=C3=A1k?= Date: Fri, 24 Apr 2026 12:27:34 +0200 Subject: [PATCH 1/2] whitespace --- rebasebot/bot.py | 55 ++++++++++++++++++++++++++--------- rebasebot/cli.py | 2 -- rebasebot/github.py | 4 ++- rebasebot/lifecycle_hooks.py | 16 +++++++--- tests/conftest.py | 4 --- tests/test_bot.py | 37 +++++++++++++++++++---- tests/test_cli.py | 20 ++++++++++--- tests/test_conflict_policy.py | 3 -- tests/test_rebases.py | 21 ++++++++++--- 9 files changed, 122 insertions(+), 40 deletions(-) diff --git a/rebasebot/bot.py b/rebasebot/bot.py index edb7bab..06ad780 100755 --- a/rebasebot/bot.py +++ b/rebasebot/bot.py @@ -1,6 +1,5 @@ #!/usr/bin/python # pylint: disable=too-many-lines - # Copyright 2022 Red Hat, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -48,8 +47,6 @@ class PullRequestUpdateException(Exception): logging.basicConfig(format="%(levelname)s - %(message)s", stream=sys.stdout, level=logging.INFO) - - MERGE_TMP_BRANCH = "merge-tmp" _COMMIT_LOG_FORMAT = "--pretty=format:%H || %s || %aE" _MERGE_COMMIT_PARENT_COUNT = 2 @@ -105,7 +102,11 @@ def _is_pr_merged(pr_number: int, source_repo: Repository, gitwd: git.Repo, sour def _add_to_rebase( - commit_message: str, source_repo: Repository, tag_policy: str, gitwd: git.Repo, source_branch: str + commit_message: str, + source_repo: Repository, + tag_policy: str, + gitwd: git.Repo, + source_branch: str, ) -> bool: valid_tag_policy = ["soft", "strict", "none"] if tag_policy not in valid_tag_policy: @@ -186,7 +187,11 @@ def _identify_downstream_commits(gitwd: git.Repo, source: GitHubBranch, dest: Gi # ancestry_path_merges are merge commits on ancestry path from merge base to destination branch ancestry_path_merges = gitwd.git.log( - _COMMIT_LOG_FORMAT, "--ancestry-path", "-r", "--merges", f"{merge_base}..dest/{dest.branch}" + _COMMIT_LOG_FORMAT, + "--ancestry-path", + "-r", + "--merges", + f"{merge_base}..dest/{dest.branch}", ).splitlines() val = "\n".join(ancestry_path_merges) @@ -216,7 +221,12 @@ def _identify_downstream_commits(gitwd: git.Repo, source: GitHubBranch, dest: Gi # Fetch all downstream (non-merge) commits with full formatting. all_downstream_lines = gitwd.git.log( - "--reverse", "--topo-order", _COMMIT_LOG_FORMAT, "--no-merges", *cutoff_commits, f"dest/{dest.branch}" + "--reverse", + "--topo-order", + _COMMIT_LOG_FORMAT, + "--no-merges", + *cutoff_commits, + f"dest/{dest.branch}", ).splitlines() downstream_shas = {line.split(" || ", 1)[0].strip() for line in all_downstream_lines if line.strip()} @@ -241,7 +251,11 @@ def _identify_downstream_commits(gitwd: git.Repo, source: GitHubBranch, dest: Gi # ancestor-or-equal to the rebase branch containing the synthetic # rebase merge commit. first_parent_merges = gitwd.git.rev_list( - "--reverse", "--first-parent", "--merges", *cutoff_commits, f"dest/{dest.branch}" + "--reverse", + "--first-parent", + "--merges", + *cutoff_commits, + f"dest/{dest.branch}", ).splitlines() rebase_pr_merge = None @@ -253,7 +267,8 @@ def _identify_downstream_commits(gitwd: git.Repo, source: GitHubBranch, dest: Gi # So the rebase commit must be reachable from parent[1], but not # yet reachable from parent[0]. if gitwd.is_ancestor(last_rebase_merge_commit, commit.parents[1]) and not gitwd.is_ancestor( - last_rebase_merge_commit, commit.parents[0] + last_rebase_merge_commit, + commit.parents[0], ): rebase_pr_merge = commit break @@ -389,8 +404,12 @@ def _check_upstream_content_loss(gitwd: git.Repo, source_branch: str, only_files return results -def _safe_cherry_pick( - gitwd: git.Repo, sha: str, source_branch: str, conflict_policy: str, commit_description: str +def _safe_cherry_pick( # hi + gitwd: git.Repo, + sha: str, + source_branch: str, + conflict_policy: str, + commit_description: str, ) -> None: """ Cherry-pick a commit with conflict detection based on conflict_policy. @@ -423,7 +442,9 @@ def _safe_cherry_pick( for filename, lost_lines in lost_content: logging.warning( - "Upstream content may have been dropped from '%s' by cherry-pick of: %s", filename, commit_description + "Upstream content may have been dropped from '%s' by cherry-pick of: %s", + filename, + commit_description, ) for line in lost_lines[:_LOST_LINE_LOG_LIMIT]: logging.warning(" lost line: %s", line.strip()) @@ -626,7 +647,10 @@ def _resolve_rebase_conflicts(gitwd: git.Repo) -> bool: def _cherrypick_art_pull_request( - gitwd: git.Repo, dest_repo: Repository, dest: GitHubBranch, conflict_policy: str = "auto" + gitwd: git.Repo, + dest_repo: Repository, + dest: GitHubBranch, + conflict_policy: str = "auto", # hi ) -> None: """ Looks at the destination repository and if there is an open ART pull request @@ -911,7 +935,12 @@ def _update_pr_title(gitwd: git.Repo, pull_req: ShortPullRequest, source: GitHub def _report_result( # pylint: disable=R0917 - needs_rebase: bool, pr_required: bool, pr_available: bool, pr_url: str, dest_url: str, slack_webhook: str + needs_rebase: bool, + pr_required: bool, + pr_available: bool, + pr_url: str, + dest_url: str, + slack_webhook: str, ) -> None: """Reports the result of sucessful rebasebot run to slack and log.""" message = None diff --git a/rebasebot/cli.py b/rebasebot/cli.py index 00be4dd..efbd3ff 100755 --- a/rebasebot/cli.py +++ b/rebasebot/cli.py @@ -1,5 +1,4 @@ #!/usr/bin/python - # Copyright 2022 Red Hat, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -13,7 +12,6 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - """This module parses CLI arguments for the Rebase Bot.""" import argparse diff --git a/rebasebot/github.py b/rebasebot/github.py index 595f70b..ddba5c2 100644 --- a/rebasebot/github.py +++ b/rebasebot/github.py @@ -126,7 +126,9 @@ def __init__( self._app_credentials = GitHubAppCredentials(app_id=app_id, app_key=app_key, github_branch=dest_branch) self._cloner_app_credentials = GitHubAppCredentials( - app_id=cloner_id, app_key=cloner_key, github_branch=rebase_branch + app_id=cloner_id, + app_key=cloner_key, + github_branch=rebase_branch, ) def get_app_token(self) -> str: diff --git a/rebasebot/lifecycle_hooks.py b/rebasebot/lifecycle_hooks.py index 0d89ca1..ff1d54b 100644 --- a/rebasebot/lifecycle_hooks.py +++ b/rebasebot/lifecycle_hooks.py @@ -11,7 +11,6 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - """This module manages user provided scripts that are executed during the rebase process.""" import logging @@ -126,7 +125,14 @@ def _fetch_from_remote_git( raise ValueError(f"Failed to retrieve script from git reference {git_path}") from e def _fetch_from_github_api( - self, *, github, organization: str, name: str, git_repo_path_to_script: str, branch: str, script_file_path: str + self, + *, + github, + organization: str, + name: str, + git_repo_path_to_script: str, + branch: str, + script_file_path: str, ): """Fetches script from GitHub API.""" try: @@ -163,7 +169,8 @@ def fetch_script(self, temp_hook_dir: str, gitwd: git.Repo = None, github: Githu return remote_git_pattern_match = re.match( - "^git:(https://([^/]+)/([^/]+)/([^/]+))/([^:]+?):(.*)$", self.script_location + "^git:(https://([^/]+)/([^/]+)/([^/]+))/([^:]+?):(.*)$", + self.script_location, ) local_git_pattern_match = re.match("^git:([^:]+):([^:]+)$", self.script_location) @@ -238,7 +245,8 @@ def __call__(self, cwd: str = None) -> LifecycleHookScriptResult: def _fetch_file_from_github(github, organization, name, branch, git_repo_path_to_script) -> Contents: return github.github_cloner_app.repository(owner=organization, repository=name).file_contents( - git_repo_path_to_script, ref=branch + git_repo_path_to_script, + ref=branch, ) diff --git a/tests/conftest.py b/tests/conftest.py index 6d86ed3..b9819b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,9 +28,7 @@ from rebasebot.github import GithubAppProvider, GitHubBranch T = TypeVar("T") - YieldFixture = Generator[T, None, None] - _GO_CODE = """ package main import ( @@ -42,12 +40,10 @@ return } """ - _ANOTHER_GO_CODE = """ package main func foo() {} """ - _GO_CODE_FILENAME = "test.go" diff --git a/tests/test_bot.py b/tests/test_bot.py index 0e9982a..7689de2 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -151,12 +151,27 @@ class TestCommitMessageTags: (False, "NO TAG: something", "strict", False), (False, "fooo fooo fooo", "strict", False), # With invalid tag policy - (False, "NO TAG: : something", "asdkjqwe", Exception("Unknown tag policy: asdkjqwe")), + ( + False, + "NO TAG: : something", + "asdkjqwe", + Exception("Unknown tag policy: asdkjqwe"), + ), (False, "NO TAG: something", "123123", Exception("Unknown tag policy: 123123")), (False, "fooo fooo fooo", "fufufu", Exception("Unknown tag policy: fufufu")), # Unknown commit tag - (False, "UPSTREAM: : something", "strict", Exception("Unknown commit message tag: ")), - (False, "UPSTREAM: commit message", "strict", Exception("Unknown commit message tag: commit message")), + ( + False, + "UPSTREAM: : something", + "strict", + Exception("Unknown commit message tag: "), + ), + ( + False, + "UPSTREAM: commit message", + "strict", + Exception("Unknown commit message tag: commit message"), + ), ), ) @patch("rebasebot.bot._is_pr_merged") @@ -290,7 +305,13 @@ class TestReportResult: "I updated existing rebase PR: https://github.com/user/repo/pull/456", ), # Rebase performed but no changes between rebase and dest (no PR needed) - (True, False, False, None, f"Destination repo {dest_url} already contains the latest changes"), + ( + True, + False, + False, + None, + f"Destination repo {dest_url} already contains the latest changes", + ), # Cases when needs_rebase is False ( False, @@ -299,7 +320,13 @@ class TestReportResult: "https://github.com/user/repo/pull/100", "PR https://github.com/user/repo/pull/100 already contains the latest changes", ), - (False, False, False, "", f"Destination repo {dest_url} already contains the latest changes"), + ( + False, + False, + False, + "", + f"Destination repo {dest_url} already contains the latest changes", + ), # Cases when hooks made changes ( False, diff --git a/tests/test_cli.py b/tests/test_cli.py index 562acf2..b90d57c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -73,16 +73,25 @@ class TestCliArgParser: ( "https://github.com/kubernetes/autoscaler:master", GitHubBranch( - url="https://github.com/kubernetes/autoscaler", ns="kubernetes", name="autoscaler", branch="master" + url="https://github.com/kubernetes/autoscaler", + ns="kubernetes", + name="autoscaler", + branch="master", ), ), ( "kubernetes/autoscaler:master", GitHubBranch( - url="https://github.com/kubernetes/autoscaler", ns="kubernetes", name="autoscaler", branch="master" + url="https://github.com/kubernetes/autoscaler", + ns="kubernetes", + name="autoscaler", + branch="master", ), ), - ("foo/bar:baz", GitHubBranch(url="https://github.com/foo/bar", ns="foo", name="bar", branch="baz")), + ( + "foo/bar:baz", + GitHubBranch(url="https://github.com/foo/bar", ns="foo", name="bar", branch="baz"), + ), ), ) @pytest.mark.parametrize("arg", ["source", "dest", "rebase"]) @@ -204,7 +213,10 @@ def test_app_credentials_valid_credentials_file_app_auth(self, mocked_run, get_v @patch("rebasebot.cli._get_github_app_wrapper") @patch("rebasebot.bot.run") def test_persistent_working_dir_when_not_specified( - self, mocked_run, mocked_get_github_app_wrapper, valid_args_dict + self, + mocked_run, + mocked_get_github_app_wrapper, + valid_args_dict, ): mocked_get_github_app_wrapper.return_value = MagicMock() diff --git a/tests/test_conflict_policy.py b/tests/test_conflict_policy.py index c332633..ce32f9f 100644 --- a/tests/test_conflict_policy.py +++ b/tests/test_conflict_policy.py @@ -11,7 +11,6 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - """Tests for --conflict-policy behavior.""" from __future__ import annotations @@ -37,7 +36,6 @@ \tec2 string } """ - # Upstream adds a new field/constant (between existing lines) _UPSTREAM_ADDED_CODE = """\ package main @@ -54,7 +52,6 @@ \tebsKmsKeyId string } """ - # Downstream carry patch reformats and adds its own field/constant # (conflicts with upstream because it modifies the same lines) _DOWNSTREAM_CARRY_CODE = """\ diff --git a/tests/test_rebases.py b/tests/test_rebases.py index f398188..99c71cd 100644 --- a/tests/test_rebases.py +++ b/tests/test_rebases.py @@ -417,7 +417,8 @@ def fake_repository_func(namespace, name): args.dry_run = False result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) mocked_message_slack.assert_called_once_with( - None, f"Repo {dest.clone_url} has PR {pr.html_url} with 'rebase/manual' label, aborting" + None, + f"Repo {dest.clone_url} has PR {pr.html_url} with 'rebase/manual' label, aborting", ) assert result @@ -477,7 +478,12 @@ def test_strict_and_excluded_commits(self, init_test_repositories, fake_github_p @patch("rebasebot.lifecycle_hooks._fetch_file_from_github") def test_lifecyclehooks_remote( - self, mock_fetch_file_from_github, init_test_repositories, fake_github_provider, tmpdir, caplog + self, + mock_fetch_file_from_github, + init_test_repositories, + fake_github_provider, + tmpdir, + caplog, ): source, rebase, dest = init_test_repositories @@ -513,7 +519,11 @@ def test_lifecyclehooks_remote( result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) mock_fetch_file_from_github.assert_called_once_with( - ANY, "openshift-eng", "rebasebot", "main", "tests/data/test-hook-script.sh" + ANY, + "openshift-eng", + "rebasebot", + "main", + "tests/data/test-hook-script.sh", ) # mock_fetch_branch.assert_called_once_with( # ANY, "github.com/openshift-eng/rebasebot", "main", @@ -740,7 +750,10 @@ def test_always_run_hooks_when_no_rebase_needed(self, init_test_repositories, fa assert "post-rebase-hook.success" in os.listdir(tmpdir) def test_hooks_not_run_when_no_rebase_needed_and_flag_false( - self, init_test_repositories, fake_github_provider, tmpdir + self, + init_test_repositories, + fake_github_provider, + tmpdir, ): """Test that hooks DON'T run when --always-run-hooks is False and no rebase is needed.""" source, rebase, dest = init_test_repositories From fccb6518d1f9e76f797cf79d68bdde05276973c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Ma=C5=88=C3=A1k?= Date: Tue, 28 Apr 2026 13:58:23 +0200 Subject: [PATCH 2/2] Add pause on error, continue functionality --- rebasebot/bot.py | 547 ++++++++++++---- rebasebot/cli.py | 54 +- rebasebot/lifecycle_hooks.py | 34 +- rebasebot/resume_flow.py | 873 +++++++++++++++++++++++++ rebasebot/resume_state.py | 153 +++++ tests/rebase_test_support.py | 125 ++++ tests/test_cli.py | 91 +++ tests/test_rebase_resume_conflicts.py | 537 +++++++++++++++ tests/test_rebase_resume_hooks.py | 494 ++++++++++++++ tests/test_rebase_resume_validation.py | 189 ++++++ tests/test_rebases.py | 349 ++++------ 11 files changed, 3080 insertions(+), 366 deletions(-) create mode 100644 rebasebot/resume_flow.py create mode 100644 rebasebot/resume_state.py create mode 100644 tests/rebase_test_support.py create mode 100644 tests/test_rebase_resume_conflicts.py create mode 100644 tests/test_rebase_resume_hooks.py create mode 100644 tests/test_rebase_resume_validation.py diff --git a/rebasebot/bot.py b/rebasebot/bot.py index 06ad780..7015619 100755 --- a/rebasebot/bot.py +++ b/rebasebot/bot.py @@ -31,9 +31,10 @@ from github3.repos.commit import ShortCommit from github3.repos.repo import Repository -from rebasebot import lifecycle_hooks +from rebasebot import lifecycle_hooks, resume_flow, resume_state from rebasebot.github import GithubAppProvider, GitHubBranch from rebasebot.lifecycle_hooks import LifecycleHookScriptException +from rebasebot.resume_flow import PauseRebaseTaskException, PausedRebaseException, ResumeFlowException class RepoException(Exception): @@ -46,6 +47,10 @@ class PullRequestUpdateException(Exception): """An error signaling an issue in updating a pull request""" +class UnresolvedConflictException(PauseRebaseTaskException): + """Raised when a cherry-pick conflict requires manual resolution.""" + + logging.basicConfig(format="%(levelname)s - %(message)s", stream=sys.stdout, level=logging.INFO) MERGE_TMP_BRANCH = "merge-tmp" _COMMIT_LOG_FORMAT = "--pretty=format:%H || %s || %aE" @@ -404,12 +409,13 @@ def _check_upstream_content_loss(gitwd: git.Repo, source_branch: str, only_files return results -def _safe_cherry_pick( # hi +def _safe_cherry_pick( gitwd: git.Repo, sha: str, source_branch: str, conflict_policy: str, commit_description: str, + pause_on_conflict: bool = False, ) -> None: """ Cherry-pick a commit with conflict detection based on conflict_policy. @@ -429,6 +435,8 @@ def _safe_cherry_pick( # hi gitwd.git.cherry_pick(f"{sha}", "-Xtheirs") except git.GitCommandError as ex: if not _resolve_rebase_conflicts(gitwd): + if pause_on_conflict: + raise UnresolvedConflictException(f"Git rebase failed: {ex}") from ex raise RepoException(f"Git rebase failed: {ex}") from ex # If no conflicts were detected, -Xtheirs had no effect — skip check @@ -453,27 +461,36 @@ def _safe_cherry_pick( # hi if conflict_policy == "strict": files = ", ".join(f for f, _ in lost_content) - raise RepoException( + message = ( f"Upstream content was lost in [{files}] after " f"cherry-picking '{commit_description}'. " f"-Xtheirs resolved a content conflict by dropping " f"upstream additions. Manual resolution is required." ) + if pause_on_conflict: + raise PauseRebaseTaskException( + message, + pause_reason=message, + resolution_instructions=( + "Review the paused commit, restore the missing upstream content, " + "amend it or add a follow-up commit, then rerun rebasebot with --continue." + ), + ) + raise RepoException(message) -def _do_rebase( +def _build_rebase_tasks( *, - gitwd: git.Repo, source: GitHubBranch, dest: GitHubBranch, + gitwd: git.Repo, source_repo: Repository, tag_policy: str, - conflict_policy: str = "auto", bot_emails: list, exclude_commits: list, update_go_modules: bool, -) -> None: - logging.info("Performing rebase") +) -> list[resume_state.ResumeTask]: + tasks: list[resume_state.ResumeTask] = [] allow_bot_squash = len(bot_emails) > 0 if allow_bot_squash: @@ -512,34 +529,276 @@ def _do_rebase( commits_to_squash[email].append({"sha": sha, "commit_message": commit_message}) continue - logging.info("Picking commit: %s - %s", sha, commit_message) - - _safe_cherry_pick( - gitwd=gitwd, - sha=sha, - source_branch=source.branch, - conflict_policy=conflict_policy, - commit_description=f"{sha} - {commit_message}", + tasks.append( + resume_state.ResumeTask( + kind="pick", + sha=sha, + source_branch=source.branch, + commit_description=f"{sha} - {commit_message}", + ) ) # Here we cherry-pick the bot's commits and then squash them together # We also want the newest bot commit message to represent the squashed commits if allow_bot_squash: for key, value in commits_to_squash.items(): - logging.info("Squashing commits for bot: %s: %s", key, value) for commit in value: - _safe_cherry_pick( - gitwd=gitwd, - sha=commit["sha"], - source_branch=source.branch, - conflict_policy=conflict_policy, - commit_description=f"{commit['sha']} - {commit['commit_message']}", + tasks.append( + resume_state.ResumeTask( + kind="pick", + sha=commit["sha"], + source_branch=source.branch, + commit_description=f"{commit['sha']} - {commit['commit_message']}", + ) ) - gitwd.git.reset("--soft", f"HEAD~{len(value)}") + tasks.append( + resume_state.ResumeTask( + kind="squash", + reset_count=len(value), + commit_message=value[-1]["commit_message"], + author=key, + ) + ) + + return tasks + + +def _build_art_pr_tasks(dest_repo: Repository, dest: GitHubBranch, gitwd: git.Repo) -> list[resume_state.ResumeTask]: + tasks: list[resume_state.ResumeTask] = [] + logging.info("Checking for ART pull request") + for pull_request in dest_repo.pull_requests(state="open", base=f"{dest.branch}"): + assert isinstance(pull_request, ShortPullRequest) # type hint + if "consistent with ART" in pull_request.title and pull_request.user.login == "openshift-bot": + logging.info(f"Found open ART image update pull requst: {pull_request.title}") + remote = pull_request.head.repository + remote_name = remote.name + if remote_name in gitwd.remotes: + gitwd.remotes[remote_name].set_url(remote.html_url) + else: + gitwd.create_remote(remote_name, remote.html_url) + + gitwd.remotes[remote_name].fetch(pull_request.head.ref) + + for commit in pull_request.commits(): + assert isinstance(commit, ShortCommit) + tasks.append( + resume_state.ResumeTask( + kind="pick", + sha=commit.sha, + source_branch=dest.branch, + commit_description=f"ART PR commit {commit.sha}", + ) + ) + + return tasks + + +def _report_manual_intervention( + source: GitHubBranch, + dest: GitHubBranch, + slack_webhook: str, + ex: Exception, +) -> None: + logging.error( + "Manual intervention is needed to rebase %s:%s into %s/%s:%s", + source.url, + source.branch, + dest.ns, + dest.name, + dest.branch, + ) + logging.error("Failure reason: %s", ex) + _message_slack( + slack_webhook, + f"Manual intervention is needed to rebase " + f"{source.url}:{source.branch} " + f"into {dest.ns}/{dest.name}:{dest.branch}: " + f"{ex}", + ) + + +def _execute_resume_tasks( # pylint: disable=too-many-arguments,too-many-positional-arguments + *, + gitwd: git.Repo, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + working_dir: str, + tasks: list[resume_state.ResumeTask], + phase: resume_state.ResumePhase, + conflict_policy: str, + pause_on_conflict: bool, + future_art_tasks: list[resume_state.ResumeTask] | None = None, +) -> None: + resume_flow.execute_rebase_tasks( + gitwd=gitwd, + source=source, + dest=dest, + rebase=rebase, + working_dir=working_dir, + tasks=tasks, + phase=phase, + conflict_policy=conflict_policy, + pause_on_conflict=pause_on_conflict, + safe_cherry_pick=_safe_cherry_pick, + pause_exception_cls=PauseRebaseTaskException, + future_art_tasks=future_art_tasks, + ) - newest_bot_commit_message = value[-1]["commit_message"] - gitwd.git.commit("-m", newest_bot_commit_message, "--author", key) +def _run_post_rebase_and_art( # pylint: disable=too-many-arguments,too-many-positional-arguments + *, + gitwd: git.Repo, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + working_dir: str, + hooks: lifecycle_hooks.LifecycleHooks, + art_tasks: list[resume_state.ResumeTask], + conflict_policy: str, + pause_on_conflict: bool, + run_art_tasks: bool = True, + post_rebase_start_script_index: int = 0, +) -> None: + flow_args = { + "gitwd": gitwd, + "source": source, + "dest": dest, + "rebase": rebase, + "working_dir": working_dir, + } + resume_flow.execute_hook_with_resume( + hook=lifecycle_hooks.LifecycleHook.POST_REBASE, + hooks=hooks, + phase=resume_state.ResumePhase.POST_REBASE, + art_tasks=art_tasks, + start_script_index=post_rebase_start_script_index, + **flow_args, + ) + if run_art_tasks: + _execute_resume_tasks( + tasks=art_tasks, + phase=resume_state.ResumePhase.ART_PR, + conflict_policy=conflict_policy, + pause_on_conflict=pause_on_conflict, + **flow_args, + ) + + +def _run_pre_carry_through_art( # pylint: disable=too-many-arguments,too-many-positional-arguments + *, + gitwd: git.Repo, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + working_dir: str, + source_repo: Repository, + dest_repo: Repository, + hooks: lifecycle_hooks.LifecycleHooks, + tag_policy: str, + conflict_policy: str, + bot_emails: list, + exclude_commits: list, + update_go_modules: bool, + pause_on_conflict: bool, + pre_carry_start_script_index: int = 0, + post_rebase_start_script_index: int = 0, +) -> None: + flow_args = { + "gitwd": gitwd, + "source": source, + "dest": dest, + "rebase": rebase, + "working_dir": working_dir, + } + resume_flow.execute_hook_with_resume( + hook=lifecycle_hooks.LifecycleHook.PRE_CARRY_COMMIT, + hooks=hooks, + phase=resume_state.ResumePhase.PRE_CARRY_COMMIT, + start_script_index=pre_carry_start_script_index, + **flow_args, + ) + art_tasks = _build_art_pr_tasks(dest_repo, dest, gitwd) + carry_tasks = _build_rebase_tasks( + gitwd=gitwd, + source=source, + dest=dest, + source_repo=source_repo, + tag_policy=tag_policy, + bot_emails=bot_emails, + exclude_commits=exclude_commits, + update_go_modules=update_go_modules, + ) + _execute_resume_tasks( + tasks=carry_tasks, + phase=resume_state.ResumePhase.CARRY_COMMITS, + conflict_policy=conflict_policy, + pause_on_conflict=pause_on_conflict, + future_art_tasks=art_tasks, + **flow_args, + ) + _run_post_rebase_and_art( + hooks=hooks, + art_tasks=art_tasks, + conflict_policy=conflict_policy, + pause_on_conflict=pause_on_conflict, + post_rebase_start_script_index=post_rebase_start_script_index, + **flow_args, + ) + + +def _run_always_run_hook_subset( + *, + hooks: lifecycle_hooks.LifecycleHooks, + gitwd: git.Repo, + working_dir: str, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + include_pre_rebase: bool, + pre_rebase_start_script_index: int = 0, + pre_carry_start_script_index: int = 0, + post_rebase_start_script_index: int = 0, +) -> None: + flow_args = { + "gitwd": gitwd, + "source": source, + "dest": dest, + "rebase": rebase, + "working_dir": working_dir, + } + if include_pre_rebase: + resume_flow.execute_hook_with_resume( + hook=lifecycle_hooks.LifecycleHook.PRE_REBASE, + hooks=hooks, + phase=resume_state.ResumePhase.PRE_REBASE, + start_script_index=pre_rebase_start_script_index, + **flow_args, + ) + resume_flow.execute_hook_with_resume( + hook=lifecycle_hooks.LifecycleHook.PRE_CARRY_COMMIT, + hooks=hooks, + phase=resume_state.ResumePhase.PRE_CARRY_COMMIT, + start_script_index=pre_carry_start_script_index, + **flow_args, + ) + resume_flow.execute_hook_with_resume( + hook=lifecycle_hooks.LifecycleHook.POST_REBASE, + hooks=hooks, + phase=resume_state.ResumePhase.POST_REBASE, + start_script_index=post_rebase_start_script_index, + **flow_args, + ) + + +def _build_flow_actions() -> resume_flow.FlowActions: + return resume_flow.FlowActions( + needs_rebase=_needs_rebase, + prepare_rebase_branch=_prepare_rebase_branch, + build_rebase_tasks=_build_rebase_tasks, + build_art_pr_tasks=_build_art_pr_tasks, + execute_rebase_tasks=_execute_resume_tasks, + ) def _prepare_rebase_branch(gitwd: git.Repo, source: GitHubBranch, dest: GitHubBranch) -> None: @@ -646,41 +905,6 @@ def _resolve_rebase_conflicts(gitwd: git.Repo) -> bool: return _resolve_rebase_conflicts(gitwd) -def _cherrypick_art_pull_request( - gitwd: git.Repo, - dest_repo: Repository, - dest: GitHubBranch, - conflict_policy: str = "auto", # hi -) -> None: - """ - Looks at the destination repository and if there is an open ART pull request - that updates the build image, it includes it in the rebase. - """ - logging.info("Checking for ART pull request") - for pull_request in dest_repo.pull_requests(state="open", base=f"{dest.branch}"): - assert isinstance(pull_request, ShortPullRequest) # type hint - if "consistent with ART" in pull_request.title and pull_request.user.login == "openshift-bot": - logging.info(f"Found open ART image update pull requst: {pull_request.title}") - remote = pull_request.head.repository - remote_name = remote.name - if remote_name in gitwd.remotes: - gitwd.remotes[remote_name].set_url(remote.html_url) - else: - gitwd.create_remote(remote_name, remote.html_url) - - gitwd.remotes[remote_name].fetch(pull_request.head.ref) - - for commit in pull_request.commits(): - assert isinstance(commit, ShortCommit) - _safe_cherry_pick( - gitwd=gitwd, - sha=commit.sha, - source_branch=dest.branch, - conflict_policy=conflict_policy, - commit_description=f"ART PR commit {commit.sha}", - ) - - def _is_push_required(gitwd: git.Repo, rebase: GitHubBranch) -> bool: # Check if there is nothing to update in the open rebase PR. if rebase.branch in gitwd.remotes.rebase.refs: @@ -785,6 +1009,7 @@ def _init_working_dir( git_username: str, git_email: str, workdir: str, + preserve_rebase_state: bool = False, ) -> git.Repo: gitwd = git.Repo.init(path=workdir) @@ -793,6 +1018,10 @@ def _init_working_dir( # checks, wrong commit filtering, wrong cherry-pick detection). Reinitializing # .git is the only safe way to clear them. if "source" in gitwd.remotes and gitwd.remotes["source"].url != source.url: + if preserve_rebase_state: + raise RepoException( + "Cannot continue paused run because the source repository URL changed in the working directory." + ) logging.warning( "Source URL changed from %s to %s; reinitializing working directory to remove stale refs", gitwd.remotes["source"].url, @@ -876,17 +1105,18 @@ def _init_working_dir( logging.info("Fetching existing rebase branch") gitwd.remotes.rebase.fetch(rebase.branch) - # Reset the existing rebase branch to match the source branch - # or create a new rebase branch based on the source branch. - head_commit = gitwd.git.rev_parse(source_ref) - if "rebase" in gitwd.heads: - gitwd.heads.rebase.set_commit(head_commit) - else: - gitwd.create_head("rebase", head_commit) - gitwd.git.checkout("rebase", force=True) - gitwd.head.reset(index=True, working_tree=True) - # Clean any untracked files when reusing rebase directory - gitwd.git.clean("-fd") + if not preserve_rebase_state: + # Reset the existing rebase branch to match the source branch + # or create a new rebase branch based on the source branch. + head_commit = gitwd.git.rev_parse(source_ref) + if "rebase" in gitwd.heads: + gitwd.heads.rebase.set_commit(head_commit) + else: + gitwd.create_head("rebase", head_commit) + gitwd.git.checkout("rebase", force=True) + gitwd.head.reset(index=True, working_tree=True) + # Clean any untracked files when reusing rebase directory + gitwd.git.clean("-fd") return gitwd @@ -979,6 +1209,33 @@ def _report_result( # pylint: disable=R0917 _message_slack(slack_webhook, message) +def _prepare_resume_state( + *, + continue_run: bool, + gitwd: git.Repo, + working_dir: str, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, +) -> resume_state.ResumeState | None: + if not continue_run: + return None + + state = resume_flow.load_and_validate_resume_state( + gitwd=gitwd, + working_dir=working_dir, + source=source, + dest=dest, + rebase=rebase, + ) + + # Resume state is only valid for explicit pause points. Clear it before + # re-entering the flow so generic failures do not leave stale progress + # snapshots behind. + resume_state.clear_resume_state(working_dir) + return state + + def run( *, source: GitHubBranch, @@ -999,6 +1256,9 @@ def run( ignore_manual_label: bool = False, always_run_hooks: bool = False, title_prefix: str = "", + pause_on_conflict: bool = False, + continue_run: bool = False, + retry_failed_step: bool = False, ) -> bool: """Run Rebase Bot.""" gh_app = github_app_provider.github_app @@ -1037,6 +1297,17 @@ def run( except FileExistsError: pass + if continue_run: + if not resume_state.has_resume_state(working_dir): + logging.error( + "No paused resume state found in %s. Run without --continue or resolve a paused run first.", + working_dir, + ) + return False + elif resume_state.has_resume_state(working_dir): + logging.error("Working directory %s contains paused resume state. Use --continue to resume it.", working_dir) + return False + try: gitwd = _init_working_dir( source=source, @@ -1046,6 +1317,7 @@ def run( git_username=git_username, git_email=git_email, workdir=working_dir, + preserve_rebase_state=continue_run, ) except Exception as ex: logging.exception( @@ -1064,6 +1336,19 @@ def run( ) return False + try: + validated_resume_state = _prepare_resume_state( + continue_run=continue_run, + gitwd=gitwd, + working_dir=working_dir, + source=source, + dest=dest, + rebase=rebase, + ) + except ResumeFlowException as ex: + _report_manual_intervention(source, dest, slack_webhook, ex) + return False + try: hooks.fetch_hook_scripts(gitwd=gitwd, github_app_provider=github_app_provider) except Exception as ex: @@ -1071,43 +1356,35 @@ def run( _message_slack(slack_webhook, f"Failed to fetch lifecycle hook scripts: {ex}") return False + flow_actions = _build_flow_actions() try: - needs_rebase = _needs_rebase(gitwd, source, dest) - if needs_rebase: - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.PRE_REBASE) - _prepare_rebase_branch(gitwd, source, dest) - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.PRE_CARRY_COMMIT) - _do_rebase( - gitwd=gitwd, - source=source, - dest=dest, - source_repo=source_repo, - tag_policy=tag_policy, - conflict_policy=conflict_policy, - bot_emails=bot_emails, - exclude_commits=exclude_commits, - update_go_modules=update_go_modules, - ) - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.POST_REBASE) - _cherrypick_art_pull_request(gitwd, dest_repo, dest, conflict_policy) - elif always_run_hooks: - # Run hooks without rebase operations when --always-run-hooks is enabled - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.PRE_REBASE) - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.PRE_CARRY_COMMIT) - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.POST_REBASE) - - except (RepoException, LifecycleHookScriptException) as ex: - logging.error( - f"Manual intervention is needed to rebase {source.url}:{source.branch} " - f"into {dest.ns}/{dest.name}:{dest.branch}" - ) - _message_slack( - slack_webhook, - f"Manual intervention is needed to rebase " - f"{source.url}:{source.branch} " - f"into {dest.ns}/{dest.name}:{dest.branch}: " - f"{ex}", + flow_result = resume_flow.execute_flow( + gitwd=gitwd, + source=source, + dest=dest, + rebase=rebase, + working_dir=working_dir, + source_repo=source_repo, + dest_repo=dest_repo, + hooks=hooks, + tag_policy=tag_policy, + conflict_policy=conflict_policy, + bot_emails=bot_emails, + exclude_commits=exclude_commits, + update_go_modules=update_go_modules, + always_run_hooks=always_run_hooks, + pause_on_conflict=pause_on_conflict, + retry_failed_step=retry_failed_step, + actions=flow_actions, + state=validated_resume_state, ) + needs_rebase = flow_result.needs_rebase + except PausedRebaseException as ex: + logging.warning(str(ex)) + _message_slack(slack_webhook, str(ex)) + raise + except (RepoException, ResumeFlowException, LifecycleHookScriptException) as ex: + _report_manual_intervention(source, dest, slack_webhook, ex) return False except Exception as ex: logging.exception( @@ -1125,6 +1402,7 @@ def run( if dry_run: logging.info("Dry run mode is enabled. Do not create a PR.") + resume_state.clear_resume_state(working_dir) return True push_required = _is_push_required(gitwd, rebase) @@ -1135,21 +1413,24 @@ def run( # Push the rebase branch to the remote repository. if push_required: logging.info("Existing rebase branch needs to be updated.") + flow_args = { + "gitwd": gitwd, + "source": source, + "dest": dest, + "rebase": rebase, + "working_dir": working_dir, + } try: - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.PRE_PUSH_REBASE_BRANCH) + if not flow_result.skip_pre_push_rebase_branch_hook: + resume_flow.execute_hook_with_resume( + hook=lifecycle_hooks.LifecycleHook.PRE_PUSH_REBASE_BRANCH, + hooks=hooks, + phase=resume_state.ResumePhase.PRE_PUSH_REBASE_BRANCH, + **flow_args, + ) _push_rebase_branch(gitwd, rebase) except LifecycleHookScriptException as ex: - logging.error( - f"Manual intervention is needed to rebase {source.url}:{source.branch} " - f"into {dest.ns}/{dest.name}:{dest.branch}" - ) - _message_slack( - slack_webhook, - f"Manual intervention is needed to rebase " - f"{source.url}:{source.branch} " - f"into {dest.ns}/{dest.name}:{dest.branch}: " - f"{ex}", - ) + _report_manual_intervention(source, dest, slack_webhook, ex) return False except Exception as ex: logging.exception(f"error pushing to {rebase.ns}/{rebase.name}:{rebase.branch}") @@ -1173,7 +1454,20 @@ def run( try: if not pr_available: - hooks.execute_scripts_for_hook(hook=lifecycle_hooks.LifecycleHook.PRE_CREATE_PR) + flow_args = { + "gitwd": gitwd, + "source": source, + "dest": dest, + "rebase": rebase, + "working_dir": working_dir, + } + if not flow_result.skip_pre_create_pr_hook: + resume_flow.execute_hook_with_resume( + hook=lifecycle_hooks.LifecycleHook.PRE_CREATE_PR, + hooks=hooks, + phase=resume_state.ResumePhase.PRE_CREATE_PR, + **flow_args, + ) pr_required = _is_pr_required(gitwd, rebase, dest) if pr_required: pr_url = _create_pr( @@ -1188,17 +1482,7 @@ def run( logging.info("No PR required - no changes between rebase and dest.") pr_url = None except LifecycleHookScriptException as ex: - logging.error( - f"Manual intervention is needed to rebase {source.url}:{source.branch} " - f"into {dest.ns}/{dest.name}:{dest.branch}" - ) - _message_slack( - slack_webhook, - f"Manual intervention is needed to rebase " - f"{source.url}:{source.branch} " - f"into {dest.ns}/{dest.name}:{dest.branch}: " - f"{ex}", - ) + _report_manual_intervention(source, dest, slack_webhook, ex) return False except requests.exceptions.HTTPError as ex: logging.error(f"Failed to create a pull request: {ex}\n Response: %s", ex.response.text) @@ -1212,4 +1496,5 @@ def run( return False _report_result(needs_rebase, pr_required, pr_available, pr_url, dest.url, slack_webhook) + resume_state.clear_resume_state(working_dir) return True diff --git a/rebasebot/cli.py b/rebasebot/cli.py index efbd3ff..569071d 100755 --- a/rebasebot/cli.py +++ b/rebasebot/cli.py @@ -20,7 +20,7 @@ import sys import tempfile -from rebasebot import bot, lifecycle_hooks +from rebasebot import bot, lifecycle_hooks, resume_state from rebasebot.github import GithubAppProvider, GitHubBranch, parse_github_branch @@ -182,6 +182,31 @@ def check_source_repo_args(namespace): required=False, help="When enabled, the bot will not create or update PR.", ) + parser.add_argument( + "--pause-on-conflict", + action="store_true", + default=False, + required=False, + help=( + "Pause and persist resume state when a cherry-pick conflict needs manual resolution, " + "or when strict conflict policy detects dropped upstream content after a cherry-pick." + ), + ) + parser.add_argument( + "--continue", + dest="continue_run", + action="store_true", + default=False, + required=False, + help="Continue a previously paused rebasebot run from the next saved step in the configured working directory.", + ) + parser.add_argument( + "--retry-failed-step", + action="store_true", + default=False, + required=False, + help="When resuming a paused hook failure, retry the failed hook script instead of skipping to the next one.", + ) parser.add_argument( "--tag-policy", default="none", @@ -329,8 +354,19 @@ def rebasebot_run(args, slack_webhook, github_app_wrapper): args.working_dir = working_dir original_cwd = os.getcwd() try: + if args.continue_run is True and args.source_repo is not None: + try: + persisted_state = resume_state.read_resume_state(working_dir) + except resume_state.ResumeStateError as e: + logging.error( + f"Error loading resume state before continue: {e}", + exc_info=True, + ) + sys.exit(1) + args.source = persisted_state.source.to_github_branch() + try: - if args.source_repo is not None: + if args.source_repo is not None and args.continue_run is not True: lifecycle_hooks.run_source_repo_hook( args=args, github_app_wrapper=github_app_wrapper, @@ -371,6 +407,9 @@ def rebasebot_run(args, slack_webhook, github_app_wrapper): hooks=hooks, always_run_hooks=args.always_run_hooks, title_prefix=args.title_prefix, + pause_on_conflict=args.pause_on_conflict is True, + continue_run=args.continue_run is True, + retry_failed_step=args.retry_failed_step is True, ) finally: os.chdir(original_cwd) @@ -400,10 +439,13 @@ def main(): gh_user_token_path=args.github_user_token, ) - if rebasebot_run(args, slack_webhook, github_app_wrapper): - sys.exit(0) - else: - sys.exit(1) + try: + if rebasebot_run(args, slack_webhook, github_app_wrapper): + sys.exit(0) + else: + sys.exit(1) + except bot.PausedRebaseException: + sys.exit(3) if __name__ == "__main__": diff --git a/rebasebot/lifecycle_hooks.py b/rebasebot/lifecycle_hooks.py index ff1d54b..7d89947 100644 --- a/rebasebot/lifecycle_hooks.py +++ b/rebasebot/lifecycle_hooks.py @@ -30,7 +30,18 @@ class LifecycleHookScriptException(Exception): - """LifecycleHookScriptException is a exception raised as a result of lifecycle hook script failure.""" + """LifecycleHookScriptException is raised when a lifecycle hook script fails.""" + + def __init__( + self, + message: str, + *, + script_index: int | None = None, + script_location: str | None = None, + ): + super().__init__(message) + self.script_index = script_index + self.script_location = script_location class LifecycleHook(Enum): @@ -339,9 +350,18 @@ def fetch_hook_scripts(self, gitwd: git.Repo, github_app_provider: GithubAppProv for script in hooks: script.fetch_script(temp_hook_dir=self.tmp_hook_scripts_dir, gitwd=gitwd, github=github_app_provider) - def execute_scripts_for_hook(self, hook: LifecycleHook): - """Executes all scripts in the given lifecycle hook.""" - for script in self.hooks.get(hook, []): + def get_scripts_for_hook(self, hook: LifecycleHook) -> list[LifecycleHookScript]: + """Returns the configured scripts for the given lifecycle hook.""" + return self.hooks.get(hook, []) + + def get_script_locations_for_hook(self, hook: LifecycleHook) -> list[str]: + """Returns configured script locations for the given lifecycle hook.""" + return [script.script_location for script in self.get_scripts_for_hook(hook)] + + def execute_scripts_for_hook(self, hook: LifecycleHook, start_index: int = 0): + """Executes hook scripts starting at the given index.""" + scripts = self.get_scripts_for_hook(hook) + for index, script in enumerate(scripts[start_index:], start=start_index): logging.info(f"Running {hook} lifecycle hook {script}") try: result = script(cwd=self.working_dir) @@ -355,4 +375,8 @@ def execute_scripts_for_hook(self, hook: LifecycleHook): except subprocess.CalledProcessError as err: logging.error(f"Script {script} failed with exit code {err.returncode}") message = f"{hook} script {script} failed with exit-code {err.returncode}" - raise LifecycleHookScriptException(message) from err + raise LifecycleHookScriptException( + message, + script_index=index, + script_location=script.script_location, + ) from err diff --git a/rebasebot/resume_flow.py b/rebasebot/resume_flow.py new file mode 100644 index 0000000..60a1e6c --- /dev/null +++ b/rebasebot/resume_flow.py @@ -0,0 +1,873 @@ +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import Any, Callable + +import git +from github3.repos.repo import Repository + +from rebasebot import lifecycle_hooks, resume_state +from rebasebot.github import GitHubBranch +from rebasebot.lifecycle_hooks import LifecycleHookScriptException + + +class PausedRebaseException(Exception): + """Raised when rebasebot intentionally pauses for manual resolution.""" + + +class ResumeFlowException(Exception): + """Raised when a paused run cannot be safely resumed.""" + + +class PauseRebaseTaskException(Exception): + """Raised when a task should pause and persist resume state.""" + + def __init__( + self, + message: str, + *, + pause_reason: str | None = None, + resolution_instructions: str | None = None, + ) -> None: + super().__init__(message) + self.pause_reason = pause_reason + self.resolution_instructions = resolution_instructions + + +@dataclass +class FlowResult: + """Carries flow results needed by later publish steps.""" + + needs_rebase: bool + skip_pre_push_rebase_branch_hook: bool = False + skip_pre_create_pr_hook: bool = False + + +@dataclass +class FlowActions: + """Operations that the flow engine delegates back to the main bot flow.""" + + needs_rebase: Callable[[git.Repo, GitHubBranch, GitHubBranch], bool] + prepare_rebase_branch: Callable[[git.Repo, GitHubBranch, GitHubBranch], None] + build_rebase_tasks: Callable[..., list[resume_state.ResumeTask]] + build_art_pr_tasks: Callable[[Repository, GitHubBranch, git.Repo], list[resume_state.ResumeTask]] + execute_rebase_tasks: Callable[..., None] + + +@dataclass +class FlowContext: + """Shared immutable inputs and computed state for a flow run.""" + + gitwd: git.Repo + source: GitHubBranch + dest: GitHubBranch + rebase: GitHubBranch + working_dir: str + source_repo: Repository + dest_repo: Repository + hooks: lifecycle_hooks.LifecycleHooks + tag_policy: str + conflict_policy: str + bot_emails: list + exclude_commits: list + update_go_modules: bool + always_run_hooks: bool + pause_on_conflict: bool + retry_failed_step: bool + resume: resume_state.ResumeState | None + needs_rebase: bool + runtime_art_tasks: list[resume_state.ResumeTask] | None = None + + def is_resume(self) -> bool: + return self.resume is not None + + def resume_phase(self) -> resume_state.ResumePhase | None: + if self.resume is None: + return None + return self.resume.phase + + def effective_pause_on_conflict(self) -> bool: + return True if self.is_resume() else self.pause_on_conflict + + def flow_args(self) -> dict[str, Any]: + return { + "gitwd": self.gitwd, + "source": self.source, + "dest": self.dest, + "rebase": self.rebase, + "working_dir": self.working_dir, + } + + def current_art_tasks(self) -> list[resume_state.ResumeTask]: + if self.runtime_art_tasks is not None: + return self.runtime_art_tasks + if self.resume is not None: + return self.resume.art_tasks + return [] + + def set_runtime_art_tasks(self, art_tasks: list[resume_state.ResumeTask]) -> None: + self.runtime_art_tasks = art_tasks + + +@dataclass(frozen=True) +class StepSpec: + """Describes one persisted phase in the unified step runner.""" + + phase: resume_state.ResumePhase + fresh_when: Callable[[FlowContext], bool] + run_fresh: Callable[[FlowContext, FlowActions], dict[str, bool] | None] + run_resume: Callable[[FlowContext, FlowActions, resume_state.ResumeState], dict[str, bool] | None] + terminal_on_resume: bool = False + + +_NO_REBASE_CONTINUE_MESSAGE = ( + "Cannot continue paused run because no rebase is needed and lifecycle hooks are not configured to run." +) + + +def persist_resume_state( + *, + gitwd: git.Repo, + working_dir: str, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + phase: resume_state.ResumePhase, + remaining_tasks: list[resume_state.ResumeTask], + art_tasks: list[resume_state.ResumeTask] | None = None, + current_task: resume_state.ResumeTask | None = None, + head_before_task: str | None = None, + head_at_pause: str | None = None, + allowed_untracked_files: list[str] | None = None, + next_hook_script_index: int | None = None, + hook_script_locations: list[str] | None = None, +) -> str: + state = resume_state.ResumeState( + source=resume_state.BranchState.from_github_branch(source), + dest=resume_state.BranchState.from_github_branch(dest), + rebase=resume_state.BranchState.from_github_branch(rebase), + source_head_sha=gitwd.commit(f"source/{source.branch}").hexsha, + dest_head_sha=gitwd.commit(f"dest/{dest.branch}").hexsha, + phase=phase, + remaining_tasks=remaining_tasks, + art_tasks=art_tasks or [], + current_task=current_task, + head_before_task=head_before_task, + head_at_pause=head_at_pause, + allowed_untracked_files=allowed_untracked_files or [], + next_hook_script_index=next_hook_script_index, + hook_script_locations=hook_script_locations, + ) + return resume_state.write_resume_state(working_dir, state) + + +def pause_rebase_for_resolution( + *, + gitwd: git.Repo, + working_dir: str, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + phase: resume_state.ResumePhase, + current_task: resume_state.ResumeTask, + remaining_tasks: list[resume_state.ResumeTask], + art_tasks: list[resume_state.ResumeTask], + head_before_task: str, + pause_reason: str | None = None, + resolution_instructions: str | None = None, +) -> None: + allowed_untracked_files = sorted(path for path in gitwd.untracked_files if path != resume_state.STATE_FILENAME) + state_path = persist_resume_state( + gitwd=gitwd, + working_dir=working_dir, + source=source, + dest=dest, + rebase=rebase, + phase=phase, + remaining_tasks=remaining_tasks, + art_tasks=art_tasks, + current_task=current_task, + head_before_task=head_before_task, + head_at_pause=gitwd.head.commit.hexsha, + allowed_untracked_files=allowed_untracked_files, + ) + default_resolution_instructions = ( + f"Resolve the conflict in {working_dir}, finish the cherry-pick with a commit or " + "'git cherry-pick --continue', then rerun rebasebot with --continue." + ) + message_parts = [f"Paused during {phase.display_name} while applying '{current_task.commit_description}'."] + if pause_reason is not None: + message_parts.append(pause_reason) + message_parts.append(resolution_instructions or default_resolution_instructions) + message_parts.append(f"Resume state saved to {state_path}.") + raise PausedRebaseException(" ".join(message_parts)) + + +def execute_hook_with_resume( + *, + hook: lifecycle_hooks.LifecycleHook, + hooks: lifecycle_hooks.LifecycleHooks, + gitwd: git.Repo, + working_dir: str, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + phase: resume_state.ResumePhase, + remaining_tasks: list[resume_state.ResumeTask] | None = None, + art_tasks: list[resume_state.ResumeTask] | None = None, + start_script_index: int = 0, +) -> None: + try: + hooks.execute_scripts_for_hook(hook=hook, start_index=start_script_index) + except LifecycleHookScriptException as ex: + state_path = persist_resume_state( + gitwd=gitwd, + working_dir=working_dir, + source=source, + dest=dest, + rebase=rebase, + phase=phase, + remaining_tasks=remaining_tasks or [], + art_tasks=art_tasks, + next_hook_script_index=ex.script_index + 1 if ex.script_index is not None else None, + hook_script_locations=hooks.get_script_locations_for_hook(hook), + ) + logging.warning( + "Saved resume state to %s after %s failed at %s. Continue with --continue to skip the failed script " + "and continue from the next saved step, or rerun with --continue --retry-failed-step after fixing " + "the issue to retry it.", + state_path, + phase.display_name, + ex.script_location or hook, + ) + raise + + +def resolve_hook_resume_index( + *, + hook: lifecycle_hooks.LifecycleHook, + hooks: lifecycle_hooks.LifecycleHooks, + state: resume_state.ResumeState, + retry_failed_step: bool = False, +) -> int: + current_hook_scripts = hooks.get_script_locations_for_hook(hook) + if state.hook_script_locations is None or state.next_hook_script_index is None: + if retry_failed_step: + raise ResumeFlowException( + f"Cannot retry failed step for {hook} because the saved hook position is unavailable." + ) + return len(current_hook_scripts) + + if current_hook_scripts != state.hook_script_locations: + raise ResumeFlowException( + f"Cannot continue paused run because configured {hook} scripts changed after the pause." + ) + + if not 0 <= state.next_hook_script_index <= len(current_hook_scripts): + raise ResumeFlowException( + f"Cannot continue paused run because saved {hook} script position is invalid." + ) + + if retry_failed_step: + if state.next_hook_script_index == 0: + raise ResumeFlowException( + f"Cannot retry failed step for {hook} because the saved hook position is invalid." + ) + return state.next_hook_script_index - 1 + + return state.next_hook_script_index + + +def execute_rebase_tasks( # pylint: disable=too-many-arguments,too-many-positional-arguments + *, + gitwd: git.Repo, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + working_dir: str, + tasks: list[resume_state.ResumeTask], + phase: resume_state.ResumePhase, + conflict_policy: str, + pause_on_conflict: bool, + safe_cherry_pick: Callable[..., None], + pause_exception_cls: type[Exception], + future_art_tasks: list[resume_state.ResumeTask] | None = None, +) -> None: + for index, task in enumerate(tasks): + if task.kind == "squash": + logging.info("Squashing commits for bot: %s", task.author) + gitwd.git.reset("--soft", f"HEAD~{task.reset_count}") + gitwd.git.commit("-m", task.commit_message, "--author", task.author) + continue + + logging.info("Picking commit: %s", task.commit_description) + head_before_task = gitwd.head.commit.hexsha + try: + safe_cherry_pick( + gitwd=gitwd, + sha=task.sha, + source_branch=task.source_branch, + conflict_policy=conflict_policy, + commit_description=task.commit_description, + pause_on_conflict=pause_on_conflict, + ) + except pause_exception_cls as ex: + pause_rebase_for_resolution( + gitwd=gitwd, + working_dir=working_dir, + source=source, + dest=dest, + rebase=rebase, + phase=phase, + current_task=task, + remaining_tasks=tasks[index + 1 :], + art_tasks=future_art_tasks or [], + head_before_task=head_before_task, + pause_reason=getattr(ex, "pause_reason", None), + resolution_instructions=getattr(ex, "resolution_instructions", None), + ) + + +def validate_resume_request( + *, + state: resume_state.ResumeState, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, +) -> None: + expected = { + "source": state.source.to_github_branch(), + "dest": state.dest.to_github_branch(), + "rebase": state.rebase.to_github_branch(), + } + actual = {"source": source, "dest": dest, "rebase": rebase} + for key, expected_branch in expected.items(): + if actual[key] != expected_branch: + raise ResumeFlowException( + f"Cannot continue paused run: {key} does not match resume state. " + f"Expected {expected_branch.ns}/{expected_branch.name}:{expected_branch.branch}." + ) + + +def validate_resume_git_state( + *, + gitwd: git.Repo, + state: resume_state.ResumeState, +) -> None: + current_head = gitwd.head.commit.hexsha + + if gitwd.head.is_detached: + raise ResumeFlowException( + "Cannot continue paused run while HEAD is detached. Check out the local rebase branch first." + ) + + if state.phase in { + resume_state.ResumePhase.CARRY_COMMITS, + resume_state.ResumePhase.POST_REBASE, + resume_state.ResumePhase.ART_PR, + resume_state.ResumePhase.PRE_PUSH_REBASE_BRANCH, + resume_state.ResumePhase.PRE_CREATE_PR, + } and gitwd.active_branch.name != "rebase": + raise ResumeFlowException( + f"Cannot continue paused run from branch '{gitwd.active_branch.name}'. " + "Check out the local rebase branch first." + ) + + cherry_pick_head = os.path.join(gitwd.git_dir, "CHERRY_PICK_HEAD") + if os.path.exists(cherry_pick_head): + raise ResumeFlowException( + "Conflict resolution is still in progress. Finish it with a commit or " + "'git cherry-pick --continue' before rerunning rebasebot --continue." + ) + + dirty_status = gitwd.git.status("--porcelain", "--untracked-files=no") + if dirty_status: + raise ResumeFlowException( + "Cannot continue paused run with staged or modified tracked files. " + "Commit, stash, or discard those changes first." + ) + + unexpected_untracked = sorted( + set(gitwd.untracked_files) - set(state.allowed_untracked_files or []) - {resume_state.STATE_FILENAME} + ) + if unexpected_untracked: + raise ResumeFlowException( + f"Cannot continue paused run with unexpected untracked files present: {', '.join(unexpected_untracked)}." + ) + + current_source_head = gitwd.commit(f"source/{state.source.branch}").hexsha + if current_source_head != state.source_head_sha: + raise ResumeFlowException( + "Cannot continue paused run because the source branch advanced after the pause. " + "Restart the rebase with the new upstream state." + ) + + current_dest_head = gitwd.commit(f"dest/{state.dest.branch}").hexsha + if current_dest_head != state.dest_head_sha: + raise ResumeFlowException( + "Cannot continue paused run because the destination branch advanced after the pause. " + "Restart the rebase with the new downstream state." + ) + + if ( + state.phase in {resume_state.ResumePhase.CARRY_COMMITS, resume_state.ResumePhase.ART_PR} + and state.current_task is not None + and state.head_before_task is not None + ): + paused_head = state.head_at_pause or state.head_before_task + if paused_head != state.head_before_task and current_head == paused_head: + raise ResumeFlowException( + "Cannot continue paused run because the paused commit was not changed after the pause. " + "Amend it, replace it, or drop it before rerunning rebasebot --continue." + ) + + if current_head == state.head_before_task: + logging.info( + "Paused task '%s' was skipped or resolved without a new commit; continuing with remaining tasks.", + state.current_task.commit_description, + ) + + return + +def load_and_validate_resume_state( + *, + gitwd: git.Repo, + working_dir: str, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, +) -> resume_state.ResumeState: + try: + state = resume_state.read_resume_state(working_dir) + except resume_state.ResumeStateError as err: + raise ResumeFlowException(str(err)) from err + validate_resume_request( + state=state, + source=source, + dest=dest, + rebase=rebase, + ) + validate_resume_git_state( + gitwd=gitwd, + state=state, + ) + return state + + +def build_flow_context( # pylint: disable=too-many-arguments,too-many-positional-arguments + *, + gitwd: git.Repo, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + working_dir: str, + source_repo: Repository, + dest_repo: Repository, + hooks: lifecycle_hooks.LifecycleHooks, + tag_policy: str, + conflict_policy: str, + bot_emails: list, + exclude_commits: list, + update_go_modules: bool, + always_run_hooks: bool, + pause_on_conflict: bool, + retry_failed_step: bool, + actions: FlowActions, + state: resume_state.ResumeState | None, +) -> FlowContext: + return FlowContext( + gitwd=gitwd, + source=source, + dest=dest, + rebase=rebase, + working_dir=working_dir, + source_repo=source_repo, + dest_repo=dest_repo, + hooks=hooks, + tag_policy=tag_policy, + conflict_policy=conflict_policy, + bot_emails=bot_emails, + exclude_commits=exclude_commits, + update_go_modules=update_go_modules, + always_run_hooks=always_run_hooks, + pause_on_conflict=pause_on_conflict, + retry_failed_step=retry_failed_step, + resume=state, + needs_rebase=actions.needs_rebase(gitwd, source, dest), + ) + + +def apply_flow_result_patch(result: FlowResult, patch: dict[str, bool] | None) -> None: + if patch is None: + return + + for key, value in patch.items(): + setattr(result, key, value) + + +def resolve_resume_index(steps: tuple[StepSpec, ...], phase: resume_state.ResumePhase) -> int: + for index, step in enumerate(steps): + if step.phase == phase: + return index + + raise ResumeFlowException(f"Unsupported resume phase: {phase}") + + +def _should_run_hook_subset(ctx: FlowContext) -> bool: + return ctx.needs_rebase or ctx.always_run_hooks + + +def _should_run_rebase_only(ctx: FlowContext) -> bool: + return ctx.needs_rebase + + +def _require_hooks_when_no_rebase(ctx: FlowContext) -> None: + if not ctx.needs_rebase and not ctx.always_run_hooks: + raise ResumeFlowException(_NO_REBASE_CONTINUE_MESSAGE) + + +def run_hook_step( + *, + ctx: FlowContext, + hook: lifecycle_hooks.LifecycleHook, + phase: resume_state.ResumePhase, + start_script_index: int = 0, + remaining_tasks: list[resume_state.ResumeTask] | None = None, + art_tasks: list[resume_state.ResumeTask] | None = None, +) -> None: + execute_hook_with_resume( + hook=hook, + hooks=ctx.hooks, + phase=phase, + remaining_tasks=remaining_tasks, + art_tasks=art_tasks, + start_script_index=start_script_index, + **ctx.flow_args(), + ) + + +def run_task_step( + *, + ctx: FlowContext, + actions: FlowActions, + tasks: list[resume_state.ResumeTask], + phase: resume_state.ResumePhase, + future_art_tasks: list[resume_state.ResumeTask] | None = None, +) -> None: + actions.execute_rebase_tasks( + tasks=tasks, + phase=phase, + conflict_policy=ctx.conflict_policy, + pause_on_conflict=ctx.effective_pause_on_conflict(), + future_art_tasks=future_art_tasks, + **ctx.flow_args(), + ) + + +def _run_pre_rebase_fresh(ctx: FlowContext, _actions: FlowActions) -> dict[str, bool] | None: + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.PRE_REBASE, + phase=resume_state.ResumePhase.PRE_REBASE, + ) + return None + + +def _run_pre_rebase_resume( + ctx: FlowContext, + _actions: FlowActions, + state: resume_state.ResumeState, +) -> dict[str, bool] | None: + _require_hooks_when_no_rebase(ctx) + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.PRE_REBASE, + phase=resume_state.ResumePhase.PRE_REBASE, + start_script_index=resolve_hook_resume_index( + hook=lifecycle_hooks.LifecycleHook.PRE_REBASE, + hooks=ctx.hooks, + state=state, + retry_failed_step=ctx.retry_failed_step, + ), + ) + return None + + +def _run_pre_carry_commit_fresh(ctx: FlowContext, actions: FlowActions) -> dict[str, bool] | None: + if ctx.needs_rebase: + actions.prepare_rebase_branch(ctx.gitwd, ctx.source, ctx.dest) + + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.PRE_CARRY_COMMIT, + phase=resume_state.ResumePhase.PRE_CARRY_COMMIT, + ) + return None + + +def _run_pre_carry_commit_resume( + ctx: FlowContext, + actions: FlowActions, + state: resume_state.ResumeState, +) -> dict[str, bool] | None: + _require_hooks_when_no_rebase(ctx) + if ctx.needs_rebase: + actions.prepare_rebase_branch(ctx.gitwd, ctx.source, ctx.dest) + + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.PRE_CARRY_COMMIT, + phase=resume_state.ResumePhase.PRE_CARRY_COMMIT, + start_script_index=resolve_hook_resume_index( + hook=lifecycle_hooks.LifecycleHook.PRE_CARRY_COMMIT, + hooks=ctx.hooks, + state=state, + retry_failed_step=ctx.retry_failed_step, + ), + ) + return None + + +def _run_carry_commits_fresh(ctx: FlowContext, actions: FlowActions) -> dict[str, bool] | None: + art_tasks = actions.build_art_pr_tasks(ctx.dest_repo, ctx.dest, ctx.gitwd) + ctx.set_runtime_art_tasks(art_tasks) + carry_tasks = actions.build_rebase_tasks( + gitwd=ctx.gitwd, + source=ctx.source, + dest=ctx.dest, + source_repo=ctx.source_repo, + tag_policy=ctx.tag_policy, + bot_emails=ctx.bot_emails, + exclude_commits=ctx.exclude_commits, + update_go_modules=ctx.update_go_modules, + ) + run_task_step( + ctx=ctx, + actions=actions, + tasks=carry_tasks, + phase=resume_state.ResumePhase.CARRY_COMMITS, + future_art_tasks=art_tasks, + ) + return None + + +def _run_carry_commits_resume( + ctx: FlowContext, + actions: FlowActions, + state: resume_state.ResumeState, +) -> dict[str, bool] | None: + ctx.set_runtime_art_tasks(state.art_tasks) + run_task_step( + ctx=ctx, + actions=actions, + tasks=state.remaining_tasks, + phase=resume_state.ResumePhase.CARRY_COMMITS, + future_art_tasks=state.art_tasks, + ) + return None + + +def _run_post_rebase_fresh(ctx: FlowContext, _actions: FlowActions) -> dict[str, bool] | None: + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.POST_REBASE, + phase=resume_state.ResumePhase.POST_REBASE, + art_tasks=ctx.current_art_tasks(), + ) + return None + + +def _run_post_rebase_resume( + ctx: FlowContext, + _actions: FlowActions, + state: resume_state.ResumeState, +) -> dict[str, bool] | None: + ctx.set_runtime_art_tasks(state.art_tasks) + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.POST_REBASE, + phase=resume_state.ResumePhase.POST_REBASE, + art_tasks=ctx.current_art_tasks(), + start_script_index=resolve_hook_resume_index( + hook=lifecycle_hooks.LifecycleHook.POST_REBASE, + hooks=ctx.hooks, + state=state, + retry_failed_step=ctx.retry_failed_step, + ), + ) + return None + + +def _run_art_pr_fresh(ctx: FlowContext, actions: FlowActions) -> dict[str, bool] | None: + run_task_step( + ctx=ctx, + actions=actions, + tasks=ctx.current_art_tasks(), + phase=resume_state.ResumePhase.ART_PR, + ) + return None + + +def _run_art_pr_resume( + ctx: FlowContext, + actions: FlowActions, + state: resume_state.ResumeState, +) -> dict[str, bool] | None: + ctx.set_runtime_art_tasks(state.art_tasks) + run_task_step( + ctx=ctx, + actions=actions, + tasks=state.remaining_tasks, + phase=resume_state.ResumePhase.ART_PR, + ) + return None + + +def build_core_steps() -> tuple[StepSpec, ...]: + return ( + StepSpec( + phase=resume_state.ResumePhase.PRE_REBASE, + fresh_when=_should_run_hook_subset, + run_fresh=_run_pre_rebase_fresh, + run_resume=_run_pre_rebase_resume, + ), + StepSpec( + phase=resume_state.ResumePhase.PRE_CARRY_COMMIT, + fresh_when=_should_run_hook_subset, + run_fresh=_run_pre_carry_commit_fresh, + run_resume=_run_pre_carry_commit_resume, + ), + StepSpec( + phase=resume_state.ResumePhase.CARRY_COMMITS, + fresh_when=_should_run_rebase_only, + run_fresh=_run_carry_commits_fresh, + run_resume=_run_carry_commits_resume, + ), + StepSpec( + phase=resume_state.ResumePhase.POST_REBASE, + fresh_when=_should_run_hook_subset, + run_fresh=_run_post_rebase_fresh, + run_resume=_run_post_rebase_resume, + ), + StepSpec( + phase=resume_state.ResumePhase.ART_PR, + fresh_when=_should_run_rebase_only, + run_fresh=_run_art_pr_fresh, + run_resume=_run_art_pr_resume, + terminal_on_resume=True, + ), + ) + + +def _continue_publish_hook_phase(ctx: FlowContext) -> FlowResult | None: + if ctx.resume is None: + return None + + if ctx.resume.phase == resume_state.ResumePhase.PRE_PUSH_REBASE_BRANCH: + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.PRE_PUSH_REBASE_BRANCH, + phase=resume_state.ResumePhase.PRE_PUSH_REBASE_BRANCH, + start_script_index=resolve_hook_resume_index( + hook=lifecycle_hooks.LifecycleHook.PRE_PUSH_REBASE_BRANCH, + hooks=ctx.hooks, + state=ctx.resume, + retry_failed_step=ctx.retry_failed_step, + ), + ) + return FlowResult(needs_rebase=ctx.needs_rebase, skip_pre_push_rebase_branch_hook=True) + + if ctx.resume.phase == resume_state.ResumePhase.PRE_CREATE_PR: + run_hook_step( + ctx=ctx, + hook=lifecycle_hooks.LifecycleHook.PRE_CREATE_PR, + phase=resume_state.ResumePhase.PRE_CREATE_PR, + start_script_index=resolve_hook_resume_index( + hook=lifecycle_hooks.LifecycleHook.PRE_CREATE_PR, + hooks=ctx.hooks, + state=ctx.resume, + retry_failed_step=ctx.retry_failed_step, + ), + ) + return FlowResult(needs_rebase=ctx.needs_rebase, skip_pre_create_pr_hook=True) + + return None + + +def _execute_core_flow(ctx: FlowContext, actions: FlowActions) -> FlowResult: + steps = build_core_steps() + result = FlowResult(needs_rebase=ctx.needs_rebase) + + if ctx.resume is None: + for step in steps: + if not step.fresh_when(ctx): + continue + apply_flow_result_patch(result, step.run_fresh(ctx, actions)) + return result + + resume_index = resolve_resume_index(steps, ctx.resume.phase) + for offset, step in enumerate(steps[resume_index:], start=resume_index): + if offset == resume_index: + patch = step.run_resume(ctx, actions, ctx.resume) + elif step.fresh_when(ctx): + patch = step.run_fresh(ctx, actions) + else: + continue + + apply_flow_result_patch(result, patch) + if step.terminal_on_resume: + return result + + return result + + +def execute_flow( # pylint: disable=too-many-arguments,too-many-positional-arguments + *, + gitwd: git.Repo, + source: GitHubBranch, + dest: GitHubBranch, + rebase: GitHubBranch, + working_dir: str, + source_repo: Repository, + dest_repo: Repository, + hooks: lifecycle_hooks.LifecycleHooks, + tag_policy: str, + conflict_policy: str, + bot_emails: list, + exclude_commits: list, + update_go_modules: bool, + always_run_hooks: bool, + pause_on_conflict: bool, + retry_failed_step: bool, + actions: FlowActions, + state: resume_state.ResumeState | None = None, +) -> FlowResult: + """Execute a fresh or resumed rebase flow and return publish-step metadata.""" + ctx = build_flow_context( + gitwd=gitwd, + source=source, + dest=dest, + rebase=rebase, + working_dir=working_dir, + source_repo=source_repo, + dest_repo=dest_repo, + hooks=hooks, + tag_policy=tag_policy, + conflict_policy=conflict_policy, + bot_emails=bot_emails, + exclude_commits=exclude_commits, + update_go_modules=update_go_modules, + always_run_hooks=always_run_hooks, + pause_on_conflict=pause_on_conflict, + retry_failed_step=retry_failed_step, + actions=actions, + state=state, + ) + + publish_result = _continue_publish_hook_phase(ctx) + if publish_result is not None: + return publish_result + + return _execute_core_flow(ctx, actions) diff --git a/rebasebot/resume_state.py b/rebasebot/resume_state.py new file mode 100644 index 0000000..530fa12 --- /dev/null +++ b/rebasebot/resume_state.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import json +import os +from dataclasses import asdict, dataclass +from enum import Enum + +from rebasebot.github import GitHubBranch + +STATE_FILENAME = ".rebasebot-resume.json" +STATE_VERSION = 4 + + +class ResumeStateError(ValueError): + """Raised when persisted resume state cannot be loaded or validated.""" + + +class ResumePhase(str, Enum): + PRE_REBASE = "pre_rebase" + PRE_CARRY_COMMIT = "pre_carry_commit" + CARRY_COMMITS = "carry_commits" + POST_REBASE = "post_rebase" + ART_PR = "art_pr" + PRE_PUSH_REBASE_BRANCH = "pre_push_rebase_branch" + PRE_CREATE_PR = "pre_create_pr" + + @property + def display_name(self) -> str: + return { + ResumePhase.PRE_REBASE: "pre-rebase hook", + ResumePhase.PRE_CARRY_COMMIT: "pre-carry hook", + ResumePhase.CARRY_COMMITS: "carry commits", + ResumePhase.POST_REBASE: "post-rebase hook", + ResumePhase.ART_PR: "ART PR commits", + ResumePhase.PRE_PUSH_REBASE_BRANCH: "pre-push hook", + ResumePhase.PRE_CREATE_PR: "pre-create-PR hook", + }[self] + + +@dataclass +class BranchState: + url: str + ns: str + name: str + branch: str + + @classmethod + def from_github_branch(cls, branch: GitHubBranch) -> BranchState: + return cls(url=branch.url, ns=branch.ns, name=branch.name, branch=branch.branch) + + def to_github_branch(self) -> GitHubBranch: + return GitHubBranch(url=self.url, ns=self.ns, name=self.name, branch=self.branch) + + +@dataclass +class ResumeTask: + kind: str + sha: str | None = None + source_branch: str | None = None + commit_description: str | None = None + commit_message: str | None = None + author: str | None = None + reset_count: int | None = None + + @classmethod + def from_dict(cls, payload: dict) -> ResumeTask: + return cls(**payload) + + +@dataclass +class ResumeState: + source: BranchState + dest: BranchState + rebase: BranchState + source_head_sha: str + dest_head_sha: str + phase: ResumePhase + remaining_tasks: list[ResumeTask] + art_tasks: list[ResumeTask] + current_task: ResumeTask | None = None + head_before_task: str | None = None + head_at_pause: str | None = None + allowed_untracked_files: list[str] | None = None + next_hook_script_index: int | None = None + hook_script_locations: list[str] | None = None + version: int = STATE_VERSION + + @classmethod + def from_dict(cls, payload: dict) -> ResumeState: + if payload.get("version") != STATE_VERSION: + raise ResumeStateError(f"Unsupported resume state version: {payload.get('version')}") + + try: + return cls( + source=BranchState(**payload["source"]), + dest=BranchState(**payload["dest"]), + rebase=BranchState(**payload["rebase"]), + source_head_sha=payload["source_head_sha"], + dest_head_sha=payload["dest_head_sha"], + phase=ResumePhase(payload["phase"]), + remaining_tasks=[ResumeTask.from_dict(task) for task in payload["remaining_tasks"]], + art_tasks=[ResumeTask.from_dict(task) for task in payload.get("art_tasks", [])], + current_task=ResumeTask.from_dict(payload["current_task"]) + if payload["current_task"] is not None + else None, + head_before_task=payload.get("head_before_task"), + head_at_pause=payload.get("head_at_pause"), + allowed_untracked_files=payload.get("allowed_untracked_files"), + next_hook_script_index=payload.get("next_hook_script_index"), + hook_script_locations=payload.get("hook_script_locations"), + version=payload["version"], + ) + except KeyError as err: + raise ResumeStateError(f"Resume state is missing field: {err.args[0]}") from err + except TypeError as err: + raise ResumeStateError("Resume state contains invalid data") from err + + +def resume_state_path(workdir: str) -> str: + return os.path.join(workdir, STATE_FILENAME) + + +def has_resume_state(workdir: str) -> bool: + return os.path.exists(resume_state_path(workdir)) + + +def write_resume_state(workdir: str, state: ResumeState) -> str: + path = resume_state_path(workdir) + payload = asdict(state) + temp_path = f"{path}.tmp" + with open(temp_path, "w", encoding="utf-8") as state_file: + json.dump(payload, state_file, indent=2, sort_keys=True) + os.replace(temp_path, path) + return path + + +def read_resume_state(workdir: str) -> ResumeState: + path = resume_state_path(workdir) + try: + with open(path, encoding="utf-8") as state_file: + payload = json.load(state_file) + except FileNotFoundError as err: + raise ResumeStateError(f"No resume state found in {workdir}") from err + except json.JSONDecodeError as err: + raise ResumeStateError(f"Resume state in {path} is not valid JSON") from err + + return ResumeState.from_dict(payload) + + +def clear_resume_state(workdir: str) -> None: + path = resume_state_path(workdir) + if os.path.exists(path): + os.remove(path) diff --git a/tests/rebase_test_support.py b/tests/rebase_test_support.py new file mode 100644 index 0000000..d67329b --- /dev/null +++ b/tests/rebase_test_support.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from unittest.mock import MagicMock + +from git import Repo + +from rebasebot.github import GitHubBranch + +from .conftest import CommitBuilder + + +@dataclass +class WorkingRepoContext: + source: GitHubBranch + rebase: GitHubBranch + dest: GitHubBranch + + working_repo: Repo + working_repo_path: str + + def fetch_remotes(self): + self.working_repo.git.fetch("--all") + + +def make_rebasebot_args( + *, + source, + dest, + rebase, + working_dir, + **overrides, +): + defaults = { + "source": source, + "source_repo": None, + "dest": dest, + "rebase": rebase, + "working_dir": working_dir, + "git_username": "test_rebasebot", + "git_email": "test@rebasebot.ocp", + "tag_policy": "soft", + "bot_emails": [], + "exclude_commits": [], + "update_go_modules": False, + "conflict_policy": "auto", + "ignore_manual_label": False, + "dry_run": False, + "pause_on_conflict": False, + "continue_run": False, + "retry_failed_step": False, + "always_run_hooks": False, + "title_prefix": "", + "pre_rebase_hook": None, + "post_rebase_hook": None, + "pre_carry_commit_hook": None, + "pre_push_rebase_branch_hook": None, + "pre_create_pr_hook": None, + } + defaults.update(overrides) + args = MagicMock() + for key, value in defaults.items(): + setattr(args, key, value) + return args + + +class FakeArtCommit: + def __init__(self, sha: str): + self.sha = sha + + +class FakeArtPullRequest: + def __init__(self, sha: str, art_repo_dir: str, branch: str): + self.title = "update image consistent with ART" + self.user = MagicMock() + self.user.login = "openshift-bot" + repository = MagicMock() + repository.name = "art-remote" + repository.html_url = art_repo_dir + self.head = MagicMock() + self.head.repository = repository + self.head.ref = branch + self.labels = [] + self._sha = sha + + def commits(self): + return [FakeArtCommit(self._sha)] + + +def setup_fake_art_pr(fake_github_provider, source, dest, rebase, art_repo_dir): + dest_repo = MagicMock() + dest_repo.clone_url = dest.url + source_repo = MagicMock() + source_repo.clone_url = source.url + rebase_repo = MagicMock() + rebase_repo.clone_url = rebase.url + + Repo.init(art_repo_dir) + art_branch = GitHubBranch(url=art_repo_dir, ns="art", name="art", branch="art-branch") + art_base_branch = GitHubBranch(url=art_repo_dir, ns="art", name="art", branch="master") + CommitBuilder(art_base_branch).add_file("art-shared.txt", "base art version\n").commit("ART base") + art_commit = CommitBuilder(art_branch).move_file("art-shared.txt", "art-side.txt").commit("ART conflicting commit") + art_pr = FakeArtPullRequest(art_commit.hexsha, art_repo_dir, art_branch.branch) + + def pull_requests(*args, **kwargs): + if kwargs.get("state") == "open" and kwargs.get("base") == dest.branch: + return [art_pr] + return [] + + dest_repo.pull_requests.side_effect = pull_requests + fake_github_provider.github_app.repository.side_effect = lambda ns, name: { + dest.name: dest_repo, + source.name: source_repo, + }[name] + fake_github_provider.github_cloner_app.repository.side_effect = lambda ns, name: {rebase.name: rebase_repo}[name] + return dest_repo, art_commit + + +def write_hook_script(script_dir: str, name: str, content: str) -> str: + path = os.path.join(script_dir, name) + with open(path, "w", encoding="utf-8") as hook_file: + hook_file.write(content) + os.chmod(path, 0o700) + return path diff --git a/tests/test_cli.py b/tests/test_cli.py index b90d57c..e71d61c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -18,10 +18,13 @@ import pytest +from rebasebot import bot, cli, resume_state from rebasebot.cli import _parse_cli_arguments from rebasebot.cli import main as cli_main from rebasebot.github import GitHubBranch +from .rebase_test_support import make_rebasebot_args + def args_dict_to_list(args_dict: dict) -> list[str]: args = [] @@ -141,6 +144,15 @@ def test_working_dir_falls_back_when_xdg_cache_home_empty(self, valid_args_dict) parsed_args = _parse_cli_arguments() assert parsed_args.working_dir == expected + def test_pause_continue_and_retry_flags_parse(self, get_valid_cli_args): + args = get_valid_cli_args({"pause-on-conflict": None, "continue": None, "retry-failed-step": None}) + with patch("sys.argv", ["rebasebot", *args]): + parsed_args = _parse_cli_arguments() + + assert parsed_args.pause_on_conflict is True + assert parsed_args.continue_run is True + assert parsed_args.retry_failed_step is True + @patch("rebasebot.bot.run") def test_no_credentials_arg(self, mocked_run, valid_args_dict, capsys): args_dict = valid_args_dict @@ -240,3 +252,82 @@ def _mocked_run(**kwargs): passed_working_dir = mocked_run.call_args.kwargs.get("working_dir") assert passed_working_dir == os.path.join(xdg_cache, "rebasebot") assert os.path.isdir(passed_working_dir) + + @patch("rebasebot.cli._get_github_app_wrapper") + @patch("rebasebot.cli.rebasebot_run") + def test_main_exits_with_paused_status( + self, + mocked_rebasebot_run, + mocked_get_github_app_wrapper, + get_valid_cli_args, + ): + mocked_get_github_app_wrapper.return_value = MagicMock() + mocked_rebasebot_run.side_effect = bot.PausedRebaseException("paused") + + args = get_valid_cli_args() + with patch("sys.argv", ["rebasebot", *args]): + with pytest.raises(SystemExit) as exit_exc: + cli_main() + + assert exit_exc.value.code == 3 + + @patch("rebasebot.lifecycle_hooks.LifecycleHooks") + @patch("rebasebot.lifecycle_hooks.run_source_repo_hook") + @patch("rebasebot.bot.run") + def test_continue_uses_persisted_source_without_rerunning_source_hook( + self, + mocked_run, + mocked_source_hook, + mocked_lifecycle_hooks, + tmp_path, + ): + working_dir = tmp_path / "working-dir" + working_dir.mkdir() + persisted_source = GitHubBranch( + url="https://github.com/source/source", + ns="source", + name="source", + branch="main", + ) + dest = GitHubBranch(url="https://github.com/dest/dest", ns="dest", name="dest", branch="main") + rebase = GitHubBranch(url="https://github.com/rebase/rebase", ns="rebase", name="rebase", branch="main") + state = resume_state.ResumeState( + source=resume_state.BranchState.from_github_branch(persisted_source), + dest=resume_state.BranchState.from_github_branch(dest), + rebase=resume_state.BranchState.from_github_branch(rebase), + source_head_sha="a" * 40, + dest_head_sha="b" * 40, + phase=resume_state.ResumePhase.CARRY_COMMITS, + remaining_tasks=[], + art_tasks=[], + current_task=resume_state.ResumeTask(kind="pick", sha="c" * 40, commit_description="paused task"), + head_before_task="d" * 40, + allowed_untracked_files=[], + ) + resume_state.write_resume_state(str(working_dir), state) + + args = make_rebasebot_args( + source=None, + dest=dest, + rebase=rebase, + working_dir=str(working_dir), + source_repo="source/source", + source_ref_hook="git:https://github.com/source/source/main:hook.sh", + git_username="test", + git_email="test@example.com", + dry_run=True, + pause_on_conflict=True, + continue_run=True, + retry_failed_step=True, + ) + + mocked_lifecycle_hooks.return_value = MagicMock() + mocked_run.return_value = True + + result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=MagicMock()) + + assert result is True + mocked_source_hook.assert_not_called() + assert args.source == persisted_source + assert mocked_run.call_args.kwargs["source"] == persisted_source + assert mocked_run.call_args.kwargs["retry_failed_step"] is True diff --git a/tests/test_rebase_resume_conflicts.py b/tests/test_rebase_resume_conflicts.py new file mode 100644 index 0000000..bd817e8 --- /dev/null +++ b/tests/test_rebase_resume_conflicts.py @@ -0,0 +1,537 @@ +from __future__ import annotations + +import logging +import os +from tempfile import TemporaryDirectory +from unittest.mock import patch + +import pytest +from git import Repo + +from rebasebot import bot, cli, resume_state + +from .conftest import CommitBuilder +from .rebase_test_support import ( + FakeArtCommit, + FakeArtPullRequest, + make_rebasebot_args, + setup_fake_art_pr, +) + + +_ORIGINAL_CODE = """\ +package main + +const ( +\tregionKey = "region" +\tebsCSIDriver = "ebs.csi.aws.com" +) + +type Snapshotter struct { +\tlog string +\tec2 string +} +""" + +_UPSTREAM_ADDED_CODE = """\ +package main + +const ( +\tregionKey = "region" +\tebsKmsKeyIDKey = "ebsKmsKeyId" +\tebsCSIDriver = "ebs.csi.aws.com" +) + +type Snapshotter struct { +\tlog string +\tec2 string +\tebsKmsKeyId string +} +""" + +_DOWNSTREAM_CARRY_CODE = """\ +package main + +const ( +\tregionKey = "region" +\tebsCSIDriver = "ebs.csi.aws.com" +\tsnapshotCreationTimeoutKey = "snapshotCreationTimeout" +) + +type Snapshotter struct { +\tlog string +\tec2 string +\tsnapshotCreationTimeout string +} +""" + +_MERGED_CODE = """\ +package main + +const ( +\tregionKey = "region" +\tebsKmsKeyIDKey = "ebsKmsKeyId" +\tebsCSIDriver = "ebs.csi.aws.com" +\tsnapshotCreationTimeoutKey = "snapshotCreationTimeout" +) + +type Snapshotter struct { +\tlog string +\tec2 string +\tebsKmsKeyId string +\tsnapshotCreationTimeout string +} +""" + + +def _set_up_strict_content_loss_history(source, dest, *, add_later_carry: bool = False) -> None: + CommitBuilder(source).update_file("test.go", _ORIGINAL_CODE).commit("set up base code") + CommitBuilder(dest).update_file("test.go", _ORIGINAL_CODE).commit("UPSTREAM: : sync base") + CommitBuilder(source).update_file("test.go", _UPSTREAM_ADDED_CODE).commit("Add KMS key support") + CommitBuilder(dest).update_file("test.go", _DOWNSTREAM_CARRY_CODE).commit( + "UPSTREAM: : add snapshot timeout" + ) + if add_later_carry: + CommitBuilder(dest).add_file("later-carry.txt", "later content\n").commit("UPSTREAM: : later carry") + + +class TestRebaseResumeConflicts: + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_pause_and_continue_after_strict_content_loss_resolution( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/strict" + _set_up_strict_content_loss_history(source, dest, add_later_carry=True) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + conflict_policy="strict", + pause_on_conflict=True, + dry_run=True, + ) + + with pytest.raises(bot.PausedRebaseException) as paused_exc: + cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + assert "Upstream content was lost" in str(paused_exc.value) + state = resume_state.read_resume_state(tmpdir) + assert state.phase == resume_state.ResumePhase.CARRY_COMMITS + assert state.current_task.commit_description.endswith("UPSTREAM: : sync base") + assert state.head_at_pause is not None + assert state.head_at_pause != state.head_before_task + + working_repo = Repo(tmpdir) + with open(os.path.join(tmpdir, "test.go"), "w", encoding="utf-8") as merged_file: + merged_file.write(_UPSTREAM_ADDED_CODE) + working_repo.git.add("test.go") + working_repo.git.commit("--amend", "--no-edit", "--allow-empty") + + args.continue_run = True + with pytest.raises(bot.PausedRebaseException) as second_paused_exc: + cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + assert "Upstream content was lost" in str(second_paused_exc.value) + state = resume_state.read_resume_state(tmpdir) + assert state.current_task.commit_description.endswith("UPSTREAM: : add snapshot timeout") + assert state.head_at_pause is not None + assert state.head_at_pause != state.head_before_task + + with open(os.path.join(tmpdir, "test.go"), "w", encoding="utf-8") as merged_file: + merged_file.write(_MERGED_CODE) + working_repo.git.add("test.go") + working_repo.git.commit("--amend", "--no-edit") + + result = cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + assert result is True + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + assert os.path.exists(os.path.join(tmpdir, "later-carry.txt")) + with open(os.path.join(tmpdir, "test.go"), encoding="utf-8") as merged_file: + merged_contents = merged_file.read() + assert merged_contents == _MERGED_CODE + assert "Upstream content was lost" in mocked_message_slack.call_args_list[0].args[1] + assert "Upstream content was lost" in mocked_message_slack.call_args_list[1].args[1] + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_continue_rejects_unchanged_commit_after_strict_content_loss_pause( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + caplog, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/strict-unchanged" + _set_up_strict_content_loss_history(source, dest) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + conflict_policy="strict", + pause_on_conflict=True, + dry_run=True, + ) + + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + args.continue_run = True + with caplog.at_level(logging.ERROR): + result = cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + assert result is False + assert os.path.exists(resume_state.resume_state_path(tmpdir)) + assert mocked_message_slack.call_count == 2 + assert "Failure reason: Cannot continue paused run because the paused commit was not changed after the pause." in caplog.text + assert "paused commit was not changed after the pause" in mocked_message_slack.call_args_list[-1].args[1] + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_pause_and_continue_after_manual_conflict_resolution( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/1" + + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + CommitBuilder(dest).add_file("later-carry.txt", "later content\n").commit("UPSTREAM: : later carry") + with CommitBuilder(dest) as cb: + cb.add_file( + "pre-rebase-hook-script.sh", + "#!/bin/bash\nset -eu\nprintf 'pre\\n' >> pre-rebase-hook.log\n", + ) + cb.add_file( + "post-rebase-hook-script.sh", + "#!/bin/bash\nset -eu\ntouch post-rebase-hook.success\n", + ) + cb.commit("UPSTREAM: : add hook scripts") + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + pre_rebase_hook=[f"git:dest/{dest.branch}:pre-rebase-hook-script.sh"], + post_rebase_hook=[f"git:dest/{dest.branch}:post-rebase-hook-script.sh"], + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException) as paused_exc: + cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + assert "Paused during carry commits" in str(paused_exc.value) + state = resume_state.read_resume_state(tmpdir) + assert state.phase == resume_state.ResumePhase.CARRY_COMMITS + assert state.current_task.commit_description.endswith("UPSTREAM: : downstream conflict") + assert mocked_push_rebase_branch.call_count == 0 + assert mocked_create_pr.call_count == 0 + assert os.path.exists(os.path.join(tmpdir, "pre-rebase-hook.log")) + assert not os.path.exists(os.path.join(tmpdir, "post-rebase-hook.success")) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + + args.continue_run = True + result = cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + assert result is True + assert mocked_push_rebase_branch.call_count == 1 + assert mocked_create_pr.call_count == 1 + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + assert os.path.exists(os.path.join(tmpdir, "post-rebase-hook.success")) + with open(os.path.join(tmpdir, "pre-rebase-hook.log"), encoding="utf-8") as hook_log: + assert hook_log.read().splitlines() == ["pre"] + assert os.path.exists(os.path.join(tmpdir, "later-carry.txt")) + assert os.path.exists(os.path.join(tmpdir, "dest-test.go")) + assert mocked_message_slack.call_args_list[0].args[1].startswith("Paused during carry commits") + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_pause_and_continue_during_art_preserves_post_rebase_boundary( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/2" + + CommitBuilder(source).add_file("art-shared.txt", "base art version\n").commit("add art base") + CommitBuilder(source).move_file("art-shared.txt", "upstream-art.txt").commit("upstream art conflict") + with CommitBuilder(dest) as cb: + cb.add_file( + "post-rebase-hook-script.sh", + "#!/bin/bash\nset -eu\nprintf 'post\\n' >> post-rebase-hook.log\n", + ) + cb.commit("UPSTREAM: : add post hook script") + + with TemporaryDirectory(prefix="rebasebot_tests_art_repo_") as art_repo_dir: + setup_fake_art_pr(fake_github_provider, source, dest, rebase, art_repo_dir) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + post_rebase_hook=[f"git:dest/{dest.branch}:post-rebase-hook-script.sh"], + ) + + with ( + patch("rebasebot.bot.ShortPullRequest", FakeArtPullRequest), + patch("rebasebot.bot.ShortCommit", FakeArtCommit), + ): + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + state = resume_state.read_resume_state(tmpdir) + assert state.phase == resume_state.ResumePhase.ART_PR + with open(os.path.join(tmpdir, "post-rebase-hook.log"), encoding="utf-8") as hook_log: + assert hook_log.read().splitlines() == ["post"] + + working_repo = Repo(tmpdir) + working_repo.git.rm("art-shared.txt", "upstream-art.txt") + working_repo.git.add("art-side.txt") + working_repo.git.cherry_pick("--continue") + + args.continue_run = True + result = cli.rebasebot_run( + args, + slack_webhook="test://webhook", + github_app_wrapper=fake_github_provider, + ) + + assert result is True + assert mocked_push_rebase_branch.call_count == 1 + assert mocked_create_pr.call_count == 1 + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + with open(os.path.join(tmpdir, "post-rebase-hook.log"), encoding="utf-8") as hook_log: + assert hook_log.read().splitlines() == ["post"] + assert mocked_message_slack.call_args_list[0].args[1].startswith("Paused during ART PR commits") + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_continue_can_pause_again_for_later_conflict( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/3" + + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + CommitBuilder(source).add_file("art-shared.txt", "base art version\n").commit("add art base") + CommitBuilder(source).move_file("art-shared.txt", "upstream-art.txt").commit("upstream art conflict") + + with TemporaryDirectory(prefix="rebasebot_tests_art_repo_") as art_repo_dir: + setup_fake_art_pr(fake_github_provider, source, dest, rebase, art_repo_dir) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + ) + + with ( + patch("rebasebot.bot.ShortPullRequest", FakeArtPullRequest), + patch("rebasebot.bot.ShortCommit", FakeArtCommit), + patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False), + ): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + + args.continue_run = True + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + state = resume_state.read_resume_state(tmpdir) + assert state.phase == resume_state.ResumePhase.ART_PR + assert state.current_task.commit_description.startswith("ART PR commit ") + + working_repo = Repo(tmpdir) + working_repo.git.rm("art-shared.txt", "upstream-art.txt") + working_repo.git.add("art-side.txt") + working_repo.git.cherry_pick("--continue") + + args.continue_run = True + result = cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) + + assert result is True + assert mocked_push_rebase_branch.call_count == 1 + assert mocked_create_pr.call_count == 1 + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + assert mocked_message_slack.call_args_list[1].args[1].startswith("Paused during ART PR commits") + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_continue_uses_snapshotted_art_tasks_from_pause( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/5" + + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + CommitBuilder(source).add_file("art-shared.txt", "base art version\n").commit("add art base") + CommitBuilder(source).move_file("art-shared.txt", "upstream-art.txt").commit("upstream art conflict") + + with TemporaryDirectory(prefix="rebasebot_tests_art_repo_") as art_repo_dir: + dest_repo, _ = setup_fake_art_pr(fake_github_provider, source, dest, rebase, art_repo_dir) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + ) + + with ( + patch("rebasebot.bot.ShortPullRequest", FakeArtPullRequest), + patch("rebasebot.bot.ShortCommit", FakeArtCommit), + patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False), + ): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + + dest_repo.pull_requests.side_effect = lambda *args, **kwargs: [] + + args.continue_run = True + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + art_resume_state = resume_state.read_resume_state(tmpdir) + assert art_resume_state.phase == resume_state.ResumePhase.ART_PR + assert art_resume_state.current_task is not None + assert art_resume_state.current_task.commit_description.startswith("ART PR commit ") + assert mocked_message_slack.call_args_list[-1].args[1].startswith("Paused during ART PR commits") + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_continue_allows_cherry_pick_skip( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/6" + + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + CommitBuilder(dest).add_file("later-carry.txt", "later content\n").commit("UPSTREAM: : later carry") + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.cherry_pick("--skip") + + args.continue_run = True + result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert result is True + assert os.path.exists(os.path.join(tmpdir, "later-carry.txt")) + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + assert mocked_push_rebase_branch.call_count == 1 + assert mocked_create_pr.call_count == 1 diff --git a/tests/test_rebase_resume_hooks.py b/tests/test_rebase_resume_hooks.py new file mode 100644 index 0000000..da0989e --- /dev/null +++ b/tests/test_rebase_resume_hooks.py @@ -0,0 +1,494 @@ +from __future__ import annotations + +import os +from tempfile import TemporaryDirectory +from unittest.mock import patch + +import pytest +from git import Repo + +from rebasebot import bot, cli, resume_state + +from .conftest import CommitBuilder +from .rebase_test_support import make_rebasebot_args, write_hook_script + + +class TestRebaseResumeHooks: + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_continue_skips_failed_post_rebase_hook_script( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/7" + + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + CommitBuilder(dest).add_file("later-carry.txt", "later content\n").commit("UPSTREAM: : later carry") + + with TemporaryDirectory(prefix="rebasebot_post_hook_") as hook_dir: + hook_log = os.path.join(hook_dir, "post-rebase-hook.log") + fail_hook_path = write_hook_script( + hook_dir, + "post-rebase-fail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'fail\\n' >> {hook_log}\n" + "exit 7\n", + ) + tail_hook_path = write_hook_script( + hook_dir, + "post-rebase-tail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'tail\\n' >> {hook_log}\n", + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + post_rebase_hook=[fail_hook_path, tail_hook_path], + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + + args.continue_run = True + first_retry_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert first_retry_result is False + post_rebase_resume_state = resume_state.read_resume_state(tmpdir) + assert post_rebase_resume_state.phase == resume_state.ResumePhase.POST_REBASE + assert post_rebase_resume_state.remaining_tasks == [] + assert post_rebase_resume_state.art_tasks == [] + assert post_rebase_resume_state.next_hook_script_index == 1 + assert post_rebase_resume_state.hook_script_locations == [fail_hook_path, tail_hook_path] + first_retry_log = working_repo.git.log("--oneline", "--grep", "later carry") + first_retry_count = len(first_retry_log.splitlines()) + + second_retry_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert second_retry_result is True + assert os.path.exists(os.path.join(tmpdir, "later-carry.txt")) + log_output = working_repo.git.log("--oneline", "--grep", "later carry") + assert len(log_output.splitlines()) == first_retry_count + with open(hook_log, encoding="utf-8") as logged_hook: + assert logged_hook.read().splitlines() == ["fail", "tail"] + assert "failed with exit-code 7" in mocked_message_slack.call_args_list[-2].args[1] + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._message_slack") + def test_continue_retries_failed_post_rebase_hook_script( + self, + mocked_message_slack, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/7-retry" + + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + CommitBuilder(dest).add_file("later-carry.txt", "later content\n").commit("UPSTREAM: : later carry") + + with TemporaryDirectory(prefix="rebasebot_post_hook_retry_") as hook_dir: + hook_log = os.path.join(hook_dir, "post-rebase-hook.log") + fail_marker = os.path.join(hook_dir, "fail-marker") + with open(fail_marker, "w", encoding="utf-8") as marker_file: + marker_file.write("fail once\n") + + retry_hook_path = write_hook_script( + hook_dir, + "post-rebase-retry.sh", + "#!/bin/bash\n" + "set -eu\n" + f"if [ -f {fail_marker} ]; then\n" + f" printf 'fail\\n' >> {hook_log}\n" + " exit 7\n" + "fi\n" + f"printf 'retry\\n' >> {hook_log}\n", + ) + tail_hook_path = write_hook_script( + hook_dir, + "post-rebase-tail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'tail\\n' >> {hook_log}\n", + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + post_rebase_hook=[retry_hook_path, tail_hook_path], + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + + args.continue_run = True + first_retry_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert first_retry_result is False + post_rebase_resume_state = resume_state.read_resume_state(tmpdir) + assert post_rebase_resume_state.phase == resume_state.ResumePhase.POST_REBASE + assert post_rebase_resume_state.next_hook_script_index == 1 + assert post_rebase_resume_state.hook_script_locations == [retry_hook_path, tail_hook_path] + + os.remove(fail_marker) + args.retry_failed_step = True + second_retry_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert second_retry_result is True + assert os.path.exists(os.path.join(tmpdir, "later-carry.txt")) + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + with open(hook_log, encoding="utf-8") as logged_hook: + assert logged_hook.read().splitlines() == ["fail", "retry", "tail"] + assert "failed with exit-code 7" in mocked_message_slack.call_args_list[-2].args[1] + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._is_push_required") + @patch("rebasebot.bot._message_slack") + def test_continue_skips_failed_pre_push_hook_script( + self, + mocked_message_slack, + mocked_is_push_required, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_push_required.return_value = True + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/8" + + CommitBuilder(source).add_file("new-upstream.txt", "new upstream state\n").commit("upstream moved") + + with TemporaryDirectory(prefix="rebasebot_pre_push_hook_") as hook_dir: + hook_log = os.path.join(hook_dir, "pre-push-hook.log") + fail_hook_path = write_hook_script( + hook_dir, + "pre-push-fail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'fail\\n' >> {hook_log}\n" + "exit 9\n", + ) + tail_hook_path = write_hook_script( + hook_dir, + "pre-push-tail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'tail\\n' >> {hook_log}\n", + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + pre_push_rebase_branch_hook=[fail_hook_path, tail_hook_path], + ) + + first_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert first_result is False + push_resume_state = resume_state.read_resume_state(tmpdir) + assert push_resume_state.phase == resume_state.ResumePhase.PRE_PUSH_REBASE_BRANCH + assert push_resume_state.next_hook_script_index == 1 + assert push_resume_state.hook_script_locations == [fail_hook_path, tail_hook_path] + assert mocked_push_rebase_branch.call_count == 0 + + args.continue_run = True + second_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert second_result is True + assert mocked_push_rebase_branch.call_count == 1 + assert mocked_create_pr.call_count == 1 + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + with open(hook_log, encoding="utf-8") as logged_hook: + assert logged_hook.read().splitlines() == ["fail", "tail"] + assert "failed with exit-code 9" in mocked_message_slack.call_args_list[-2].args[1] + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._is_push_required") + @patch("rebasebot.bot._message_slack") + def test_continue_skips_failed_pre_create_pr_hook_script( + self, + mocked_message_slack, + mocked_is_push_required, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_push_required.side_effect = [True, False] + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/9" + + CommitBuilder(source).add_file("new-upstream.txt", "new upstream state\n").commit("upstream moved") + + with TemporaryDirectory(prefix="rebasebot_pre_create_hook_") as hook_dir: + hook_log = os.path.join(hook_dir, "pre-create-hook.log") + fail_hook_path = write_hook_script( + hook_dir, + "pre-create-fail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'fail\\n' >> {hook_log}\n" + "exit 11\n", + ) + tail_hook_path = write_hook_script( + hook_dir, + "pre-create-tail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'tail\\n' >> {hook_log}\n", + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + pre_create_pr_hook=[fail_hook_path, tail_hook_path], + ) + + first_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert first_result is False + pre_create_resume_state = resume_state.read_resume_state(tmpdir) + assert pre_create_resume_state.phase == resume_state.ResumePhase.PRE_CREATE_PR + assert pre_create_resume_state.next_hook_script_index == 1 + assert pre_create_resume_state.hook_script_locations == [fail_hook_path, tail_hook_path] + assert mocked_push_rebase_branch.call_count == 1 + assert mocked_create_pr.call_count == 0 + + args.continue_run = True + second_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert second_result is True + assert mocked_push_rebase_branch.call_count == 1 + assert mocked_create_pr.call_count == 1 + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + with open(hook_log, encoding="utf-8") as logged_hook: + assert logged_hook.read().splitlines() == ["fail", "tail"] + assert "failed with exit-code 11" in mocked_message_slack.call_args_list[-2].args[1] + + @patch("rebasebot.bot._message_slack") + def test_continue_skips_failed_pre_rebase_hook_script( + self, + mocked_message_slack, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + CommitBuilder(source).add_file("new-upstream.txt", "new upstream state\n").commit("upstream moved") + CommitBuilder(dest).add_file("later-carry.txt", "later content\n").commit("UPSTREAM: : later carry") + + with TemporaryDirectory(prefix="rebasebot_pre_rebase_hook_") as hook_dir: + hook_log = os.path.join(hook_dir, "pre-rebase-hook.log") + fail_hook_path = write_hook_script( + hook_dir, + "pre-rebase-fail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'fail\\n' >> {hook_log}\n" + "exit 5\n", + ) + tail_hook_path = write_hook_script( + hook_dir, + "pre-rebase-tail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'tail\\n' >> {hook_log}\n", + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + pause_on_conflict=True, + pre_rebase_hook=[fail_hook_path, tail_hook_path], + ) + + first_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert first_result is False + pre_rebase_resume_state = resume_state.read_resume_state(tmpdir) + assert pre_rebase_resume_state.phase == resume_state.ResumePhase.PRE_REBASE + assert pre_rebase_resume_state.next_hook_script_index == 1 + assert pre_rebase_resume_state.hook_script_locations == [fail_hook_path, tail_hook_path] + + args.continue_run = True + second_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert second_result is True + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + with open(hook_log, encoding="utf-8") as logged_hook: + assert logged_hook.read().splitlines() == ["fail", "tail"] + assert "failed with exit-code 5" in mocked_message_slack.call_args_list[-1].args[1] + + @patch("rebasebot.bot._message_slack") + def test_continue_skips_failed_pre_carry_hook_script( + self, + mocked_message_slack, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + CommitBuilder(source).add_file("new-upstream.txt", "new upstream state\n").commit("upstream moved") + CommitBuilder(dest).add_file("later-carry.txt", "later content\n").commit("UPSTREAM: : later carry") + + with TemporaryDirectory(prefix="rebasebot_pre_carry_hook_") as hook_dir: + hook_log = os.path.join(hook_dir, "pre-carry-hook.log") + fail_hook_path = write_hook_script( + hook_dir, + "pre-carry-fail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'fail\\n' >> {hook_log}\n" + "exit 6\n", + ) + tail_hook_path = write_hook_script( + hook_dir, + "pre-carry-tail.sh", + "#!/bin/bash\n" + "set -eu\n" + f"printf 'tail\\n' >> {hook_log}\n", + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + pause_on_conflict=True, + pre_carry_commit_hook=[fail_hook_path, tail_hook_path], + ) + + first_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert first_result is False + pre_carry_resume_state = resume_state.read_resume_state(tmpdir) + assert pre_carry_resume_state.phase == resume_state.ResumePhase.PRE_CARRY_COMMIT + assert pre_carry_resume_state.next_hook_script_index == 1 + assert pre_carry_resume_state.hook_script_locations == [fail_hook_path, tail_hook_path] + + args.continue_run = True + second_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert second_result is True + assert not os.path.exists(resume_state.resume_state_path(tmpdir)) + with open(hook_log, encoding="utf-8") as logged_hook: + assert logged_hook.read().splitlines() == ["fail", "tail"] + assert "failed with exit-code 6" in mocked_message_slack.call_args_list[-1].args[1] + + @patch("rebasebot.bot._create_pr") + @patch("rebasebot.bot._push_rebase_branch") + @patch("rebasebot.bot._is_pr_available") + @patch("rebasebot.bot._is_push_required") + @patch("rebasebot.bot._message_slack") + def test_continue_rejects_untracked_artifacts_from_hook_failure( + self, + mocked_message_slack, + mocked_is_push_required, + mocked_is_pr_available, + mocked_push_rebase_branch, + mocked_create_pr, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + mocked_is_push_required.return_value = True + mocked_is_pr_available.return_value = None, False + mocked_push_rebase_branch.return_value = True + mocked_create_pr.return_value = "https://github.com/example/rebase/pull/10" + CommitBuilder(source).add_file("new-upstream.txt", "new upstream state\n").commit("upstream moved") + + with TemporaryDirectory(prefix="rebasebot_untracked_hook_") as hook_dir: + fail_hook_path = write_hook_script( + hook_dir, + "pre-push-artifact.sh", + "#!/bin/bash\n" + "set -eu\n" + "touch hook-artifact.txt\n" + "exit 13\n", + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pause_on_conflict=True, + pre_push_rebase_branch_hook=[fail_hook_path], + ) + + first_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert first_result is False + assert resume_state.read_resume_state(tmpdir).phase == resume_state.ResumePhase.PRE_PUSH_REBASE_BRANCH + + args.continue_run = True + second_result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert second_result is False + assert mocked_message_slack.call_count == 2 + assert "unexpected untracked files present: hook-artifact.txt" in mocked_message_slack.call_args_list[-1].args[1] diff --git a/tests/test_rebase_resume_validation.py b/tests/test_rebase_resume_validation.py new file mode 100644 index 0000000..e7514fe --- /dev/null +++ b/tests/test_rebase_resume_validation.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import os +from unittest.mock import patch + +import pytest +from git import Repo + +from rebasebot import bot, cli, resume_state + +from .conftest import CommitBuilder +from .rebase_test_support import make_rebasebot_args + + +class TestRebaseResumeValidation: + @patch("rebasebot.bot._message_slack") + def test_continue_rejects_wrong_branch( + self, + mocked_message_slack, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + pause_on_conflict=True, + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + working_repo.git.checkout("-b", "other-branch") + + args.continue_run = True + result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert result is False + assert mocked_message_slack.call_count == 2 + assert "Check out the local rebase branch first." in mocked_message_slack.call_args_list[-1].args[1] + + @patch("rebasebot.bot._message_slack") + def test_continue_rejects_unexpected_untracked_files( + self, + mocked_message_slack, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + pause_on_conflict=True, + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + with open(os.path.join(tmpdir, "rogue.txt"), "w", encoding="utf-8") as rogue_file: + rogue_file.write("unexpected\n") + + args.continue_run = True + result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert result is False + assert mocked_message_slack.call_count == 2 + assert "unexpected untracked files present: rogue.txt" in mocked_message_slack.call_args_list[-1].args[1] + + @pytest.mark.parametrize( + ("branch_to_advance", "expected_message"), + ( + ("source", "source branch advanced after the pause"), + ("dest", "destination branch advanced after the pause"), + ), + ) + @patch("rebasebot.bot._message_slack") + def test_continue_rejects_moved_branch_heads( + self, + mocked_message_slack, + branch_to_advance, + expected_message, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + pause_on_conflict=True, + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + + if branch_to_advance == "source": + CommitBuilder(source).add_file("source-advanced.txt", "new upstream state\n").commit("source moved") + else: + CommitBuilder(dest).add_file("dest-advanced.txt", "new downstream state\n").commit("dest moved") + + args.continue_run = True + result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert result is False + assert mocked_message_slack.call_count == 2 + assert expected_message in mocked_message_slack.call_args_list[-1].args[1] + + @patch("rebasebot.bot._message_slack") + def test_continue_validates_stale_heads_before_hook_fetch( + self, + mocked_message_slack, + init_test_repositories, + fake_github_provider, + tmpdir, + ): + source, rebase, dest = init_test_repositories + CommitBuilder(source).move_file("test.go", "source-test.go").commit("rename upstream file") + CommitBuilder(dest).move_file("test.go", "dest-test.go").commit("UPSTREAM: : downstream conflict") + + hook_name = "post-rebase-hook-script.sh" + CommitBuilder(dest).add_file(hook_name, "#!/bin/bash\nset -eu\ntouch should-not-run\n").commit( + "UPSTREAM: : add post hook" + ) + + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + pause_on_conflict=True, + post_rebase_hook=[f"git:dest/{dest.branch}:{hook_name}"], + ) + + with patch("rebasebot.bot._resolve_rebase_conflicts", return_value=False): + with pytest.raises(bot.PausedRebaseException): + cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + working_repo = Repo(tmpdir) + working_repo.git.rm("source-test.go", "test.go") + working_repo.git.add("dest-test.go") + working_repo.git.cherry_pick("--continue") + + CommitBuilder(dest).remove_file(hook_name).commit("remove hook after pause") + CommitBuilder(dest).add_file("dest-advanced.txt", "new downstream state\n").commit("dest moved") + + args.continue_run = True + result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) + + assert result is False + assert mocked_message_slack.call_count == 2 + assert "destination branch advanced after the pause" in mocked_message_slack.call_args_list[-1].args[1] + assert "Failed to fetch lifecycle hook scripts" not in mocked_message_slack.call_args_list[-1].args[1] diff --git a/tests/test_rebases.py b/tests/test_rebases.py index 99c71cd..5c0df3a 100644 --- a/tests/test_rebases.py +++ b/tests/test_rebases.py @@ -1,13 +1,13 @@ from __future__ import annotations import os -from dataclasses import dataclass +from tempfile import TemporaryDirectory from unittest.mock import ANY, MagicMock, patch import pytest from git import Repo -from rebasebot import cli +from rebasebot import bot, cli, resume_state from rebasebot.bot import ( _init_working_dir, _needs_rebase, @@ -16,19 +16,14 @@ from rebasebot.github import GitHubBranch, parse_github_branch from .conftest import CommitBuilder - - -@dataclass -class WorkingRepoContext: - source: GitHubBranch - rebase: GitHubBranch - dest: GitHubBranch - - working_repo: Repo - working_repo_path: str - - def fetch_remotes(self): - self.working_repo.git.fetch("--all") +from .rebase_test_support import ( + FakeArtCommit, + FakeArtPullRequest, + WorkingRepoContext, + make_rebasebot_args, + setup_fake_art_pr, + write_hook_script, +) class TestBotInternalHelpers: @@ -93,21 +88,13 @@ def test_simple_dry_run(self, init_test_repositories, fake_github_provider, tmpd source, rebase, dest = init_test_repositories CommitBuilder(source).add_file("baz.txt", "fiz").commit("other upstream commit") - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "soft" - args.bot_emails = [] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.ignore_manual_label = False - args.dry_run = True + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) assert result @@ -144,21 +131,14 @@ def test_squash_bot_dry_run(self, init_test_repositories, fake_github_provider, cb.add_file("generated-test3", "content") cb.commit("commit #1 from anotherbot", committer_email="anotherbot@example.com") - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "soft" - args.bot_emails = ["genbot@example.com", "anotherbot@example.com"] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.ignore_manual_label = False - args.dry_run = True + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + bot_emails=["genbot@example.com", "anotherbot@example.com"], + dry_run=True, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) assert result @@ -208,21 +188,13 @@ def test_first_run_dest_has_merges_dry_run(self, init_test_repositories, fake_gi cb.add_file("generated-test3", "content") cb.commit("commit #1 from anotherbot", committer_email="anotherbot@example.com") - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "soft" - args.bot_emails = [] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.ignore_manual_label = False - args.dry_run = True + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) assert result @@ -275,21 +247,13 @@ def test_first_run_dest_merges_feature_branch_dry_run(self, init_test_repositori repo.git.checkout(dest.branch) repo.git.merge("--no-ff", "-m", "Merge branch 'feature'", repo.heads.feature) - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "soft" - args.bot_emails = [] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.ignore_manual_label = False - args.dry_run = True + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + dry_run=True, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) assert result @@ -332,21 +296,13 @@ def test_conflict( with CommitBuilder(dest) as cb: cb.commit("Empty commit") - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "soft" - args.bot_emails = ["genbot@example.com", "anotherbot@example.com"] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.ignore_manual_label = False - args.dry_run = False + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + bot_emails=["genbot@example.com", "anotherbot@example.com"], + ) result = cli.rebasebot_run(args, slack_webhook="test://webhook", github_app_wrapper=fake_github_provider) assert result @@ -400,21 +356,12 @@ def fake_repository_func(namespace, name): fake_github_provider.github_app.repository = fake_repository_func - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "soft" - args.bot_emails = [] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.ignore_manual_label = False - args.dry_run = False + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) mocked_message_slack.assert_called_once_with( None, @@ -438,20 +385,16 @@ def test_strict_and_excluded_commits(self, init_test_repositories, fake_github_p cb.add_file("carry-file2", "content") drop_commit = cb.commit("UPSTREAM: : dropped by exclude_commits") - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "strict" - args.bot_emails = ["genbot@example.com", "anotherbot@example.com"] - args.exclude_commits = [drop_commit.hexsha] - args.update_go_modules = False - args.conflict_policy = "auto" - args.dry_run = True + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + tag_policy="strict", + bot_emails=["genbot@example.com", "anotherbot@example.com"], + exclude_commits=[drop_commit.hexsha], + dry_run=True, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) assert result @@ -500,21 +443,16 @@ def test_lifecyclehooks_remote( cb.add_file("carry-file1", "content") cb.commit("UPSTREAM: : carry commit #1") - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "strict" - args.bot_emails = ["genbot@example.com", "anotherbot@example.com"] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.dry_run = True - args.post_rebase_hook = ["git:https://github.com/openshift-eng/rebasebot/main:tests/data/test-hook-script.sh"] # noqa: E501 + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + tag_policy="strict", + bot_emails=["genbot@example.com", "anotherbot@example.com"], + dry_run=True, + post_rebase_hook=["git:https://github.com/openshift-eng/rebasebot/main:tests/data/test-hook-script.sh"], # noqa: E501 + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) @@ -566,21 +504,16 @@ def test_lifecyclehooks(self, init_test_repositories, fake_github_provider, tmpd ) cb.commit("UPSTREAM: : add test hook script") - args = MagicMock() - args.source = source - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "strict" - args.bot_emails = ["genbot@example.com", "anotherbot@example.com"] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.dry_run = True - args.post_rebase_hook = [f"git:dest/{dest.branch}:test-hook-script.sh"] - args.source_repo = None + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + tag_policy="strict", + bot_emails=["genbot@example.com", "anotherbot@example.com"], + dry_run=True, + post_rebase_hook=[f"git:dest/{dest.branch}:test-hook-script.sh"], + ) assert cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) @@ -620,21 +553,16 @@ def test_lifecyclehook_fail(self, init_test_repositories, fake_github_provider, ) cb.commit("UPSTREAM: : add test hook script") - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.pre_rebase_hook = [f"git:dest/{dest.branch}:test-failure-hook-script.sh"] - args.tag_policy = "strict" - args.bot_emails = ["genbot@example.com", "anotherbot@example.com"] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.dry_run = True + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + pre_rebase_hook=[f"git:dest/{dest.branch}:test-failure-hook-script.sh"], + tag_policy="strict", + bot_emails=["genbot@example.com", "anotherbot@example.com"], + dry_run=True, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) @@ -677,22 +605,17 @@ def fake_parse_github_branch(location): mock_parse_github_branch_hooks.side_effect = fake_parse_github_branch mock_parse_github_branch_cli.side_effect = fake_parse_github_branch - args = MagicMock() - args.source = None - args.source_repo = f"{source.ns}/{source.name}" url = "git:https://github.com/openshift-eng/rebasebot/main:tests/data/test-source-ref-hook-script.sh" - args.source_ref_hook = url - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "strict" - args.bot_emails = [] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.dry_run = True + args = make_rebasebot_args( + source=None, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + source_repo=f"{source.ns}/{source.name}", + source_ref_hook=url, + tag_policy="strict", + dry_run=True, + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider) @@ -716,29 +639,18 @@ def test_always_run_hooks_when_no_rebase_needed(self, init_test_repositories, fa cb.commit("UPSTREAM: : add test hook scripts") # Configure args with always_run_hooks=True and multiple hook types to test - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "none" - args.bot_emails = [] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.dry_run = True - args.ignore_manual_label = True - args.always_run_hooks = True - - # Test multiple hook types with different scripts - args.pre_rebase_hook = [f"git:dest/{dest.branch}:pre-rebase-hook-script.sh"] - args.post_rebase_hook = [f"git:dest/{dest.branch}:post-rebase-hook-script.sh"] - args.pre_carry_commit_hook = None - args.pre_push_rebase_branch_hook = None - args.pre_create_pr_hook = None + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + tag_policy="none", + dry_run=True, + ignore_manual_label=True, + always_run_hooks=True, + pre_rebase_hook=[f"git:dest/{dest.branch}:pre-rebase-hook-script.sh"], + post_rebase_hook=[f"git:dest/{dest.branch}:post-rebase-hook-script.sh"], + ) # Verify no rebase is needed initially (source and dest are in sync) # But hooks should still run due to always_run_hooks=True @@ -771,28 +683,17 @@ def test_hooks_not_run_when_no_rebase_needed_and_flag_false( cb.commit("UPSTREAM: : add test hook scripts") # Configure args with always_run_hooks=False (default behavior) - args = MagicMock() - args.source = source - args.source_repo = None - args.dest = dest - args.rebase = rebase - args.working_dir = tmpdir - args.git_username = "test_rebasebot" - args.git_email = "test@rebasebot.ocp" - args.tag_policy = "none" - args.bot_emails = [] - args.exclude_commits = [] - args.update_go_modules = False - args.conflict_policy = "auto" - args.dry_run = True - args.ignore_manual_label = True - args.always_run_hooks = False # Key difference: hooks should NOT run - - args.pre_rebase_hook = [f"git:dest/{dest.branch}:pre-rebase-hook-script.sh"] - args.post_rebase_hook = [f"git:dest/{dest.branch}:post-rebase-hook-script.sh"] - args.pre_carry_commit_hook = None - args.pre_push_rebase_branch_hook = None - args.pre_create_pr_hook = None + args = make_rebasebot_args( + source=source, + dest=dest, + rebase=rebase, + working_dir=tmpdir, + tag_policy="none", + dry_run=True, + ignore_manual_label=True, + pre_rebase_hook=[f"git:dest/{dest.branch}:pre-rebase-hook-script.sh"], + post_rebase_hook=[f"git:dest/{dest.branch}:post-rebase-hook-script.sh"], + ) result = cli.rebasebot_run(args, slack_webhook=None, github_app_wrapper=fake_github_provider)