From d778071ed63322b3124403ecb33e1c6e470bc673 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Fri, 17 Jun 2022 17:20:24 +0900 Subject: [PATCH 1/2] scmrepo: bump to 0.0.25 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index e2547d87da..accafdf6c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -66,7 +66,7 @@ install_requires = typing-extensions>=3.7.4 fsspec[http]>=2021.10.1 aiohttp-retry>=2.4.5 - scmrepo==0.0.24 + scmrepo==0.0.25 dvc-render==0.0.6 dvclive>=0.7.3 dvc-data==0.0.6 From 09500a71d864632168c3b433aa1a50e4367c6f5e Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Fri, 17 Jun 2022 17:22:28 +0900 Subject: [PATCH 2/2] exp apply: preserve untracked files --- dvc/repo/experiments/apply.py | 87 ++++++++++++++-------- tests/func/experiments/test_experiments.py | 19 +++++ 2 files changed, 75 insertions(+), 31 deletions(-) diff --git a/dvc/repo/experiments/apply.py b/dvc/repo/experiments/apply.py index 2e0e9cec55..71fb28ee15 100644 --- a/dvc/repo/experiments/apply.py +++ b/dvc/repo/experiments/apply.py @@ -1,5 +1,7 @@ import logging import os +from contextlib import contextmanager +from typing import TYPE_CHECKING, Optional from dvc.repo import locked from dvc.repo.scm_context import scm_context @@ -13,13 +15,18 @@ ) from .executor.base import BaseExecutor +if TYPE_CHECKING: + from scmrepo import Git + + from dvc.repo import Repo + logger = logging.getLogger(__name__) @locked @scm_context -def apply(repo, rev, force=True, **kwargs): - from scmrepo.exceptions import MergeConflictError +def apply(repo: "Repo", rev: str, force: bool = True, **kwargs): + from scmrepo.exceptions import SCMError as _SCMError from dvc.repo.checkout import checkout as dvc_checkout from dvc.scm import GitMergeError, RevError, resolve_rev @@ -38,37 +45,14 @@ def apply(repo, rev, force=True, **kwargs): ): raise InvalidExpRevError(exp_rev) - # Note that we don't use stash_workspace() here since we need finer control - # over the merge behavior when we unstash everything - if repo.scm.is_dirty(untracked_files=True): - logger.debug("Stashing workspace") - workspace = repo.scm.stash.push(include_untracked=True) - else: - workspace = None - - from scmrepo.exceptions import SCMError as _SCMError - - try: - repo.scm.merge(exp_rev, commit=False, squash=True) - except _SCMError as exc: - raise GitMergeError(str(exc), scm=repo.scm) - - if workspace: + # NOTE: we don't use scmrepo's stash_workspace() here since we need + # finer control over the merge behavior when we unstash everything + with _apply_workspace(repo, rev, force): try: - repo.scm.stash.apply(workspace) - except MergeConflictError as exc: - # Applied experiment conflicts with user's workspace changes - if force: - # prefer applied experiment changes over prior stashed changes - repo.scm.checkout_index(ours=True) - else: - # revert applied changes and restore user's workspace - repo.scm.reset(hard=True) - repo.scm.stash.pop() - raise ApplyConflictError(rev) from exc + repo.scm.merge(exp_rev, commit=False, squash=True) except _SCMError as exc: - raise ApplyConflictError(rev) from exc - repo.scm.stash.drop() + raise GitMergeError(str(exc), scm=repo.scm) + repo.scm.reset() if stash_rev: @@ -84,3 +68,44 @@ def apply(repo, rev, force=True, **kwargs): "workspace.", rev, ) + + +@contextmanager +def _apply_workspace(repo: "Repo", rev: str, force: bool): + from scmrepo.exceptions import MergeConflictError + from scmrepo.exceptions import SCMError as _SCMError + + if repo.scm.is_dirty(untracked_files=True): + logger.debug("Stashing workspace") + stash_rev: Optional[str] = repo.scm.stash.push(include_untracked=True) + else: + stash_rev = None + try: + yield + except Exception: # pylint: disable=broad-except + if stash_rev: + _clean_and_pop(repo.scm) + raise + if not stash_rev: + return + + try: + repo.scm.reset() + repo.scm.stash.apply(stash_rev, skip_conflicts=force) + repo.scm.stash.drop() + except (MergeConflictError, _SCMError) as exc: + _clean_and_pop(repo.scm) + raise ApplyConflictError(rev) from exc + except Exception: # pylint: disable=broad-except + _clean_and_pop(repo.scm) + raise + + +def _clean_and_pop(scm: "Git"): + """Revert any changes and pop the last stash entry.""" + scm.reset(hard=True) + if scm.is_dirty(untracked_files=True): + # drop any changes to untracked files before popping stash + scm.stash.push(include_untracked=True) + scm.stash.drop() + scm.stash.pop() diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 727022bbf9..6ee62ff72e 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -217,6 +217,25 @@ def test_apply(tmp_dir, scm, dvc, exp_stage, queue): ) +def test_apply_untracked(tmp_dir, scm, dvc, exp_stage): + from dvc.repo.experiments.base import ApplyConflictError + + results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"]) + exp = first(results) + tmp_dir.gen("untracked", "untracked") + tmp_dir.gen("params.yaml", "conflict") + + with pytest.raises(ApplyConflictError): + dvc.experiments.apply(exp, force=False) + + assert (tmp_dir / "untracked").read_text() == "untracked" + assert (tmp_dir / "params.yaml").read_text() == "conflict" + + dvc.experiments.apply(exp, force=True) + assert (tmp_dir / "untracked").read_text() == "untracked" + assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 2" + + def test_get_baseline(tmp_dir, scm, dvc, exp_stage): from dvc.repo.experiments.base import EXPS_STASH