Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 56 additions & 31 deletions dvc/repo/experiments/apply.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional

from dvc.repo import locked
from dvc.repo.scm_context import scm_context
Expand All @@ -13,13 +15,18 @@
)
from .executor.base import BaseExecutor

if TYPE_CHECKING:
from scmrepo import Git

from dvc.repo import Repo

logger = logging.getLogger(__name__)


@locked
@scm_context
def apply(repo, rev, force=True, **kwargs):
from scmrepo.exceptions import MergeConflictError
def apply(repo: "Repo", rev: str, force: bool = True, **kwargs):
from scmrepo.exceptions import SCMError as _SCMError

from dvc.repo.checkout import checkout as dvc_checkout
from dvc.scm import GitMergeError, RevError, resolve_rev
Expand All @@ -38,37 +45,14 @@ def apply(repo, rev, force=True, **kwargs):
):
raise InvalidExpRevError(exp_rev)

# Note that we don't use stash_workspace() here since we need finer control
# over the merge behavior when we unstash everything
if repo.scm.is_dirty(untracked_files=True):
logger.debug("Stashing workspace")
workspace = repo.scm.stash.push(include_untracked=True)
else:
workspace = None

from scmrepo.exceptions import SCMError as _SCMError

try:
repo.scm.merge(exp_rev, commit=False, squash=True)
except _SCMError as exc:
raise GitMergeError(str(exc), scm=repo.scm)

if workspace:
# NOTE: we don't use scmrepo's stash_workspace() here since we need
# finer control over the merge behavior when we unstash everything
with _apply_workspace(repo, rev, force):
try:
repo.scm.stash.apply(workspace)
except MergeConflictError as exc:
# Applied experiment conflicts with user's workspace changes
if force:
# prefer applied experiment changes over prior stashed changes
repo.scm.checkout_index(ours=True)
else:
# revert applied changes and restore user's workspace
repo.scm.reset(hard=True)
repo.scm.stash.pop()
raise ApplyConflictError(rev) from exc
repo.scm.merge(exp_rev, commit=False, squash=True)
except _SCMError as exc:
raise ApplyConflictError(rev) from exc
repo.scm.stash.drop()
raise GitMergeError(str(exc), scm=repo.scm)

repo.scm.reset()

if stash_rev:
Expand All @@ -84,3 +68,44 @@ def apply(repo, rev, force=True, **kwargs):
"workspace.",
rev,
)


@contextmanager
def _apply_workspace(repo: "Repo", rev: str, force: bool):
from scmrepo.exceptions import MergeConflictError
from scmrepo.exceptions import SCMError as _SCMError

if repo.scm.is_dirty(untracked_files=True):
logger.debug("Stashing workspace")
stash_rev: Optional[str] = repo.scm.stash.push(include_untracked=True)
else:
stash_rev = None
try:
yield
except Exception: # pylint: disable=broad-except
if stash_rev:
_clean_and_pop(repo.scm)
raise
if not stash_rev:
return

try:
repo.scm.reset()
repo.scm.stash.apply(stash_rev, skip_conflicts=force)
repo.scm.stash.drop()
except (MergeConflictError, _SCMError) as exc:
_clean_and_pop(repo.scm)
raise ApplyConflictError(rev) from exc
except Exception: # pylint: disable=broad-except
_clean_and_pop(repo.scm)
raise


def _clean_and_pop(scm: "Git"):
"""Revert any changes and pop the last stash entry."""
scm.reset(hard=True)
if scm.is_dirty(untracked_files=True):
# drop any changes to untracked files before popping stash
scm.stash.push(include_untracked=True)
scm.stash.drop()
scm.stash.pop()
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ install_requires =
typing-extensions>=3.7.4
fsspec[http]>=2021.10.1
aiohttp-retry>=2.4.5
scmrepo==0.0.24
scmrepo==0.0.25
dvc-render==0.0.6
dvclive>=0.7.3
dvc-data==0.0.6
Expand Down
19 changes: 19 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,25 @@ def test_apply(tmp_dir, scm, dvc, exp_stage, queue):
)


def test_apply_untracked(tmp_dir, scm, dvc, exp_stage):
from dvc.repo.experiments.base import ApplyConflictError

results = dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
exp = first(results)
tmp_dir.gen("untracked", "untracked")
tmp_dir.gen("params.yaml", "conflict")

with pytest.raises(ApplyConflictError):
dvc.experiments.apply(exp, force=False)

assert (tmp_dir / "untracked").read_text() == "untracked"
assert (tmp_dir / "params.yaml").read_text() == "conflict"

dvc.experiments.apply(exp, force=True)
assert (tmp_dir / "untracked").read_text() == "untracked"
assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 2"


def test_get_baseline(tmp_dir, scm, dvc, exp_stage):
from dvc.repo.experiments.base import EXPS_STASH

Expand Down