diff --git a/dvc/cache/base.py b/dvc/cache/base.py index 422a2548ad..38e84869c5 100644 --- a/dvc/cache/base.py +++ b/dvc/cache/base.py @@ -5,7 +5,12 @@ from shortuuid import uuid import dvc.prompt as prompt -from dvc.exceptions import CheckoutError, ConfirmRemoveError, DvcException +from dvc.exceptions import ( + CheckoutError, + ConfirmRemoveError, + DvcException, + MergeError, +) from dvc.path_info import WindowsPathInfo from dvc.progress import Tqdm from dvc.remote.slow_link_detection import slow_link_guard @@ -552,3 +557,82 @@ def get_files_number(self, path_info, hash_, filter_info): filter_info.isin_or_eq(path_info / entry[self.tree.PARAM_CHECKSUM]) for entry in self.get_dir_cache(hash_) ) + + def _to_dict(self, dir_info): + return { + entry[self.tree.PARAM_RELPATH]: entry[self.tree.PARAM_CHECKSUM] + for entry in dir_info + } + + def _from_dict(self, dir_dict): + return [ + { + self.tree.PARAM_RELPATH: relpath, + self.tree.PARAM_CHECKSUM: checksum, + } + for relpath, checksum in dir_dict.items() + ] + + @staticmethod + def _diff(ancestor, other, allow_removed=False): + from dictdiffer import diff + + allowed = ["add"] + if allow_removed: + allowed.append("remove") + + result = list(diff(ancestor, other)) + for typ, _, _ in result: + if typ not in allowed: + raise MergeError( + "unable to auto-merge directories with diff that contains " + f"'{typ}'ed files" + ) + return result + + def _merge_dirs(self, ancestor_info, our_info, their_info): + from operator import itemgetter + + from dictdiffer import patch + + ancestor = self._to_dict(ancestor_info) + our = self._to_dict(our_info) + their = self._to_dict(their_info) + + our_diff = self._diff(ancestor, our) + if not our_diff: + return self._from_dict(their) + + their_diff = self._diff(ancestor, their) + if not their_diff: + return self._from_dict(our) + + # make sure there are no conflicting files + self._diff(our, their, allow_removed=True) + + merged = patch(our_diff + their_diff, ancestor, in_place=True) + + # Sorting the list by path to ensure reproducibility + return sorted( + self._from_dict(merged), key=itemgetter(self.tree.PARAM_RELPATH) + ) + + def merge(self, ancestor_info, our_info, their_info): + assert our_info + assert their_info + + if ancestor_info: + ancestor_hash = ancestor_info[self.tree.PARAM_CHECKSUM] + ancestor = self.get_dir_cache(ancestor_hash) + else: + ancestor = [] + + our_hash = our_info[self.tree.PARAM_CHECKSUM] + our = self.get_dir_cache(our_hash) + + their_hash = their_info[self.tree.PARAM_CHECKSUM] + their = self.get_dir_cache(their_hash) + + merged = self._merge_dirs(ancestor, our, their) + typ, merged_hash = self.tree.save_dir_info(merged) + return {typ: merged_hash} diff --git a/dvc/command/git_hook.py b/dvc/command/git_hook.py index e82cc6f624..0b2cc4ad92 100644 --- a/dvc/command/git_hook.py +++ b/dvc/command/git_hook.py @@ -61,6 +61,26 @@ def _run(self): return main(["push"]) +class CmdMergeDriver(CmdHookBase): + def _run(self): + from dvc.dvcfile import Dvcfile + from dvc.repo import Repo + + dvc = Repo() + + try: + with dvc.state: + ancestor = Dvcfile(dvc, self.args.ancestor, verify=False) + our = Dvcfile(dvc, self.args.our, verify=False) + their = Dvcfile(dvc, self.args.their, verify=False) + + our.merge(ancestor, their) + + return 0 + finally: + dvc.close() + + def add_parser(subparsers, parent_parser): GIT_HOOK_HELP = "Run GIT hook." @@ -113,3 +133,27 @@ def add_parser(subparsers, parent_parser): "args", nargs="*", help="Arguments passed by GIT or pre-commit tool.", ) pre_push_parser.set_defaults(func=CmdPrePush) + + MERGE_DRIVER_HELP = "Run GIT merge driver." + merge_driver_parser = git_hook_subparsers.add_parser( + "merge-driver", + parents=[parent_parser], + description=MERGE_DRIVER_HELP, + help=MERGE_DRIVER_HELP, + ) + merge_driver_parser.add_argument( + "--ancestor", + required=True, + help="Ancestor's version of the conflicting file.", + ) + merge_driver_parser.add_argument( + "--our", + required=True, + help="Current version of the conflicting file.", + ) + merge_driver_parser.add_argument( + "--their", + required=True, + help="Other branch's version of the conflicting file.", + ) + merge_driver_parser.set_defaults(func=CmdMergeDriver) diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index 697b8fc42d..f523192544 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -56,9 +56,10 @@ def check_dvc_filename(path): class FileMixin: SCHEMA = None - def __init__(self, repo, path, **kwargs): + def __init__(self, repo, path, verify=True, **kwargs): self.repo = repo self.path = path + self.verify = verify def __repr__(self): return "{}: {}".format( @@ -90,7 +91,8 @@ def _load(self): # 3. path doesn't represent a regular file if not self.exists(): raise StageFileDoesNotExistError(self.path) - check_dvc_filename(self.path) + if self.verify: + check_dvc_filename(self.path) if not self.repo.tree.isfile(self.path): raise StageFileIsNotDvcFileError(self.path) @@ -115,6 +117,9 @@ def remove(self, force=False): # pylint: disable=unused-argument def dump(self, stage, **kwargs): raise NotImplementedError + def merge(self, ancestor, other): + raise NotImplementedError + class SingleStageFile(FileMixin): from dvc.schema import COMPILED_SINGLE_STAGE_SCHEMA as SCHEMA @@ -134,7 +139,8 @@ def dump(self, stage, **kwargs): from dvc.stage import PipelineStage assert not isinstance(stage, PipelineStage) - check_dvc_filename(self.path) + if self.verify: + check_dvc_filename(self.path) logger.debug( "Saving information to '{file}'.".format(file=relpath(self.path)) ) @@ -144,6 +150,14 @@ def dump(self, stage, **kwargs): def remove_stage(self, stage): # pylint: disable=unused-argument self.remove() + def merge(self, ancestor, other): + assert isinstance(ancestor, SingleStageFile) + assert isinstance(other, SingleStageFile) + + stage = self.stage + stage.merge(ancestor.stage, other.stage) + self.dump(stage) + class PipelineFile(FileMixin): """Abstraction for pipelines file, .yaml + .lock combined.""" @@ -161,7 +175,8 @@ def dump( from dvc.stage import PipelineStage assert isinstance(stage, PipelineStage) - check_dvc_filename(self.path) + if self.verify: + check_dvc_filename(self.path) if update_pipeline and not stage.is_data_source: self._dump_pipeline_file(stage) @@ -239,6 +254,9 @@ def remove_stage(self, stage): else: super().remove() + def merge(self, ancestor, other): + raise NotImplementedError + class Lockfile(FileMixin): from dvc.schema import COMPILED_LOCKFILE_SCHEMA as SCHEMA @@ -295,6 +313,9 @@ def remove_stage(self, stage): else: self.remove() + def merge(self, ancestor, other): + raise NotImplementedError + class Dvcfile: def __new__(cls, repo, path, **kwargs): diff --git a/dvc/exceptions.py b/dvc/exceptions.py index d6586a35db..8e6e7c78b0 100644 --- a/dvc/exceptions.py +++ b/dvc/exceptions.py @@ -351,3 +351,7 @@ def __init__(self, target, file): f"'{target}' " f"does not exist as an output or a stage name in '{file}'" ) + + +class MergeError(DvcException): + pass diff --git a/dvc/output/base.py b/dvc/output/base.py index 54f30b54af..01572a4b55 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -10,6 +10,7 @@ from dvc.exceptions import ( CollectCacheError, DvcException, + MergeError, RemoteCacheRequiredError, ) @@ -516,3 +517,37 @@ def _validate_output_path(cls, path, stage=None): check = stage.repo.tree.dvcignore.check_ignore(path) if check.match: raise cls.IsIgnoredError(check) + + def _check_can_merge(self, out): + if self.scheme != out.scheme: + raise MergeError("unable to auto-merge outputs of different types") + + my = self.dumpd() + other = out.dumpd() + + my.pop(self.tree.PARAM_CHECKSUM) + other.pop(self.tree.PARAM_CHECKSUM) + + if my != other: + raise MergeError( + "unable to auto-merge outputs with different options" + ) + + if not out.is_dir_checksum: + raise MergeError( + "unable to auto-merge outputs that are not directories" + ) + + def merge(self, ancestor, other): + assert other + + if ancestor: + self._check_can_merge(ancestor) + ancestor_info = ancestor.info + else: + ancestor_info = None + + self._check_can_merge(self) + self._check_can_merge(other) + + self.info = self.cache.merge(ancestor_info, self.info, other.info) diff --git a/dvc/scm/git.py b/dvc/scm/git.py index 5df270bdae..775b4dc0b5 100644 --- a/dvc/scm/git.py +++ b/dvc/scm/git.py @@ -307,7 +307,21 @@ def _install_hook(self, name): os.chmod(hook, 0o777) + def _install_merge_driver(self): + self.repo.git.config("merge.dvc.name", "DVC merge driver") + self.repo.git.config( + "merge.dvc.driver", + ( + "dvc git-hook merge-driver " + "--ancestor %O " + "--our %A " + "--their %B " + ), + ) + def install(self, use_pre_commit_tool=False): + self._install_merge_driver() + if not use_pre_commit_tool: self._verify_dvc_hooks() self._install_hook("post-checkout") diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index 61f9c11914..5f485301d6 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -7,7 +7,7 @@ import dvc.dependency as dependency import dvc.prompt as prompt -from dvc.exceptions import CheckoutError, DvcException +from dvc.exceptions import CheckoutError, DvcException, MergeError from dvc.utils import relpath from . import params @@ -538,6 +538,44 @@ def get_used_cache(self, *args, **kwargs): return cache + @staticmethod + def _check_can_merge(stage, ancestor_out=None): + if isinstance(stage, PipelineStage): + raise MergeError("unable to auto-merge pipeline stages") + + if not stage.is_data_source or stage.deps or len(stage.outs) > 1: + raise MergeError( + "unable to auto-merge DVC-files that weren't " + "created by `dvc add`" + ) + + if ancestor_out and not stage.outs: + raise MergeError( + "unable to auto-merge DVC-files with deleted outputs" + ) + + def merge(self, ancestor, other): + assert other + + if not other.outs: + return + + if not self.outs: + self.outs = other.outs + return + + if ancestor: + self._check_can_merge(ancestor) + outs = ancestor.outs + ancestor_out = outs[0] if outs else None + else: + ancestor_out = None + + self._check_can_merge(self, ancestor_out) + self._check_can_merge(other, ancestor_out) + + self.outs[0].merge(ancestor_out, other.outs[0]) + class PipelineStage(Stage): def __init__(self, *args, name=None, **kwargs): @@ -577,3 +615,6 @@ def changed_stage(self): def _changed_stage_entry(self): return f"'cmd' of {self} has changed." + + def merge(self, ancestor, other): + raise NotImplementedError diff --git a/dvc/tree/base.py b/dvc/tree/base.py index 3494e3272f..41f0e1a90f 100644 --- a/dvc/tree/base.py +++ b/dvc/tree/base.py @@ -280,7 +280,7 @@ def get_dir_hash(self, path_info, **kwargs): raise RemoteCacheRequiredError(path_info) dir_info = self._collect_dir(path_info, **kwargs) - return self._save_dir_info(dir_info) + return self.save_dir_info(dir_info) def hash_to_path_info(self, hash_): return self.path_info / hash_[0:2] / hash_[2:] @@ -345,7 +345,7 @@ def _collect_dir(self, path_info, **kwargs): # Sorting the list by path to ensure reproducibility return sorted(result, key=itemgetter(self.PARAM_RELPATH)) - def _save_dir_info(self, dir_info): + def save_dir_info(self, dir_info): typ, hash_, tmp_info = self._get_dir_info_hash(dir_info) new_info = self.cache.tree.hash_to_path_info(hash_) if self.cache.changed_cache_file(hash_): @@ -359,6 +359,9 @@ def _save_dir_info(self, dir_info): return typ, hash_ def _get_dir_info_hash(self, dir_info): + # Sorting the list by path to ensure reproducibility + dir_info = sorted(dir_info, key=itemgetter(self.PARAM_RELPATH)) + tmp = tempfile.NamedTemporaryFile(delete=False).name with open(tmp, "w+") as fobj: json.dump(dir_info, fobj, sort_keys=True) diff --git a/setup.py b/setup.py index 687f7353c1..f7117037d0 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ def run(self): "dpath>=2.0.1,<3", "shtab>=1.3.0,<2", "rich>=3.0.5", + "dictdiffer>=0.8.1", ] diff --git a/tests/func/test_install.py b/tests/func/test_install.py index d71f26dcf6..80be612140 100644 --- a/tests/func/test_install.py +++ b/tests/func/test_install.py @@ -68,3 +68,65 @@ def test_pre_push_hook(self, tmp_dir, scm, dvc, tmp_path_factory): scm.repo.git.push("origin", "master") assert expected_storage_path.is_file() assert expected_storage_path.read_text() == "file_content" + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Git hooks aren't supported on Windows" +) +def test_merge_driver_no_ancestor(tmp_dir, scm, dvc): + scm.commit("init") + scm.install() + (tmp_dir / ".gitattributes").write_text("*.dvc merge=dvc") + scm.checkout("one", create_new=True) + tmp_dir.dvc_gen({"data": {"foo": "foo"}}, commit="one: add data") + + scm.checkout("master") + scm.checkout("two", create_new=True) + tmp_dir.dvc_gen({"data": {"bar": "bar"}}, commit="two: add data") + + scm.repo.git.merge("one", m="merged") + + # NOTE: dvc shouldn't checkout automatically as it might take a long time + assert (tmp_dir / "data").read_text() == {"bar": "bar"} + assert (tmp_dir / "data.dvc").read_text() == ( + "outs:\n" + "- md5: 5ea40360f5b4ec688df672a4db9c17d1.dir\n" + " path: data\n" + ) + + dvc.checkout("data.dvc") + assert (tmp_dir / "data").read_text() == {"foo": "foo", "bar": "bar"} + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Git hooks aren't supported on Windows" +) +def test_merge_driver(tmp_dir, scm, dvc): + scm.commit("init") + scm.install() + (tmp_dir / ".gitattributes").write_text("*.dvc merge=dvc") + tmp_dir.dvc_gen({"data": {"master": "master"}}, commit="master: add data") + + scm.checkout("one", create_new=True) + tmp_dir.dvc_gen({"data": {"one": "one"}}, commit="one: add data") + + scm.checkout("master") + scm.checkout("two", create_new=True) + tmp_dir.dvc_gen({"data": {"two": "two"}}, commit="two: add data") + + scm.repo.git.merge("one", m="merged") + + # NOTE: dvc shouldn't checkout automatically as it might take a long time + assert (tmp_dir / "data").read_text() == {"master": "master", "two": "two"} + assert (tmp_dir / "data.dvc").read_text() == ( + "outs:\n" + "- md5: 839ef9371606817569c1ee0e5f4ed233.dir\n" + " path: data\n" + ) + + dvc.checkout("data.dvc") + assert (tmp_dir / "data").read_text() == { + "master": "master", + "one": "one", + "two": "two", + } diff --git a/tests/func/test_merge_driver.py b/tests/func/test_merge_driver.py new file mode 100644 index 0000000000..e0e631653d --- /dev/null +++ b/tests/func/test_merge_driver.py @@ -0,0 +1,241 @@ +import os + +import pytest + +from dvc.main import main +from dvc.utils.fs import remove + + +def _gen(tmp_dir, struct, name): + remove(tmp_dir / "data") + if struct is None: + (tmp_dir / name).touch() + else: + (stage,) = tmp_dir.dvc_gen({"data": struct}) + os.rename(stage.path, name) + + +@pytest.mark.parametrize( + "ancestor, our, their, merged", + [ + ( + {"foo": "foo"}, + {"foo": "foo", "bar": "bar"}, + {"foo": "foo", "baz": "baz"}, + {"foo": "foo", "bar": "bar", "baz": "baz"}, + ), + ( + {"common": "common", "subdir": {"foo": "foo"}}, + {"common": "common", "subdir": {"foo": "foo", "bar": "bar"}}, + {"common": "common", "subdir": {"foo": "foo", "baz": "baz"}}, + { + "common": "common", + "subdir": {"foo": "foo", "bar": "bar", "baz": "baz"}, + }, + ), + ({}, {"foo": "foo"}, {"bar": "bar"}, {"foo": "foo", "bar": "bar"},), + ({}, {}, {"bar": "bar"}, {"bar": "bar"},), + ({}, {"foo": "foo"}, {}, {"foo": "foo"},), + (None, {"foo": "foo"}, {"bar": "bar"}, {"foo": "foo", "bar": "bar"},), + (None, None, {"bar": "bar"}, {"bar": "bar"},), + (None, {"foo": "foo"}, None, {"foo": "foo"},), + ], +) +def test_merge(tmp_dir, dvc, ancestor, our, their, merged): + _gen(tmp_dir, ancestor, "ancestor") + _gen(tmp_dir, our, "our") + _gen(tmp_dir, their, "their") + + assert ( + main( + [ + "git-hook", + "merge-driver", + "--ancestor", + "ancestor", + "--our", + "our", + "--their", + "their", + ] + ) + == 0 + ) + + _gen(tmp_dir, merged, "merged") + + assert (tmp_dir / "our").read_text() == (tmp_dir / "merged").read_text() + + +@pytest.mark.parametrize( + "ancestor, our, their, error", + [ + ( + {"foo": "foo"}, + {"foo": "bar"}, + {"foo": "baz"}, + ( + "unable to auto-merge directories with " + "diff that contains 'change'ed files" + ), + ), + ( + {"common": "common", "foo": "foo"}, + {"common": "common", "bar": "bar"}, + {"baz": "baz"}, + ( + "unable to auto-merge directories with " + "diff that contains 'remove'ed files" + ), + ), + ], +) +def test_merge_conflict(tmp_dir, dvc, ancestor, our, their, error, caplog): + _gen(tmp_dir, ancestor, "ancestor") + _gen(tmp_dir, our, "our") + _gen(tmp_dir, their, "their") + + assert ( + main( + [ + "git-hook", + "merge-driver", + "--ancestor", + "ancestor", + "--our", + "our", + "--their", + "their", + ] + ) + != 0 + ) + + assert error in caplog.text + + +@pytest.mark.parametrize( + "workspace", [pytest.lazy_fixture("ssh")], indirect=True +) +def test_merge_different_output_types(tmp_dir, dvc, caplog, workspace): + (tmp_dir / "ancestor").touch() + + (tmp_dir / "our").write_text( + "outs:\n- md5: f123456789.dir\n path: ssh://example.com/path\n" + ) + + (tmp_dir / "their").write_text( + "outs:\n- md5: f987654321.dir\n path: path\n" + ) + + assert ( + main( + [ + "git-hook", + "merge-driver", + "--ancestor", + "ancestor", + "--our", + "our", + "--their", + "their", + ] + ) + != 0 + ) + + error = "unable to auto-merge outputs of different types" + assert error in caplog.text + + +def test_merge_different_output_options(tmp_dir, dvc, caplog): + (tmp_dir / "ancestor").touch() + + (tmp_dir / "our").write_text( + "outs:\n- md5: f123456789.dir\n path: path\n" + ) + + (tmp_dir / "their").write_text( + "outs:\n- md5: f987654321.dir\n path: path\n cache: false\n" + ) + + assert ( + main( + [ + "git-hook", + "merge-driver", + "--ancestor", + "ancestor", + "--our", + "our", + "--their", + "their", + ] + ) + != 0 + ) + + error = "unable to auto-merge outputs with different options" + assert error in caplog.text + + +def test_merge_file(tmp_dir, dvc, caplog): + (tmp_dir / "ancestor").touch() + + (tmp_dir / "our").write_text( + "outs:\n- md5: f123456789.dir\n path: path\n" + ) + + (tmp_dir / "their").write_text("outs:\n- md5: f987654321\n path: path\n") + + assert ( + main( + [ + "git-hook", + "merge-driver", + "--ancestor", + "ancestor", + "--our", + "our", + "--their", + "their", + ] + ) + != 0 + ) + + err = "unable to auto-merge outputs that are not directories" + assert err in caplog.text + + +def test_merge_non_dvc_add(tmp_dir, dvc, caplog): + (tmp_dir / "ancestor").touch() + + (tmp_dir / "our").write_text( + "outs:\n" + "- md5: f123456789.dir\n" + " path: path\n" + "- md5: ff123456789.dir\n" + " path: another\n" + ) + + (tmp_dir / "their").write_text("outs:\n- md5: f987654321\n path: path\n") + + assert ( + main( + [ + "git-hook", + "merge-driver", + "--ancestor", + "ancestor", + "--our", + "our", + "--their", + "their", + ] + ) + != 0 + ) + + error = "unable to auto-merge DVC-files that weren't created by `dvc add`" + assert error in caplog.text