diff --git a/src/dvclive/live.py b/src/dvclive/live.py index b4310fcb..e641452d 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -36,6 +36,7 @@ logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "INFO").upper()) ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]] +StrPath = Union[str, Path] class Live: @@ -57,6 +58,7 @@ def __init__( self._images: Dict[str, Any] = {} self._params: Dict[str, Any] = {} self._plots: Dict[str, Any] = {} + self._outs: Set[StrPath] = set() self._inside_with = False self._dvcyaml = dvcyaml @@ -76,6 +78,7 @@ def __init__( self._experiment_rev: Optional[str] = None self._inside_dvc_exp: bool = False self._dvc_repo = None + self._include_untracked: List[str] = [] self._init_dvc() self._latest_studio_step = self.step if resume else -1 @@ -131,6 +134,7 @@ def _init_dvc(self): if self._save_dvc_exp: self._exp_name = get_random_exp_name(self._dvc_repo.scm, self._baseline_rev) mark_dvclive_only_started() + self._include_untracked.append(self.dir) def _init_studio(self): if not os.getenv(STUDIO_TOKEN, None): @@ -300,14 +304,31 @@ def log_params(self, params: Dict[str, ParamLike]): self._dump_params() logger.debug(f"Logged {params} parameters to {self.params_file}") - def log_param( - self, - name: str, - val: ParamLike, - ): + def log_param(self, name: str, val: ParamLike): """Saves the given parameter value to yaml""" self.log_params({name: val}) + def log_artifact(self, path: StrPath): + """Tracks a local file or directory with DVC""" + if not isinstance(path, (str, Path)): + raise InvalidDataTypeError(path, type(path)) + + if self._dvc_repo is not None: + try: + stage = self._dvc_repo.add(path) + except Exception as e: # pylint: disable=broad-except + logger.warning(f"Failed to dvc add {path}: {e}") + return + + self._outs.add(path) + dvc_file = stage[0].addressing + + if self._save_dvc_exp: + self._include_untracked.append(dvc_file) + self._include_untracked.append( + str(Path(dvc_file).parent / ".gitignore") + ) + def make_summary(self, update_step: bool = True): if self._step is not None and update_step: self.summary["step"] = self.step @@ -378,7 +399,9 @@ def end(self): try: self._experiment_rev = self._dvc_repo.experiments.save( - name=self._exp_name, include_untracked=[self.dir], force=True + name=self._exp_name, + include_untracked=self._include_untracked, + force=True, ) except DvcException as e: logger.warning(f"Failed to save experiment:\n{e}") diff --git a/tests/conftest.py b/tests/conftest.py index d4d208ef..1a3ec47b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,27 @@ def tmp_dir(tmp_path, monkeypatch): yield tmp_path +@pytest.fixture +def mocked_dvc_repo(mocker): + _dvc_repo = mocker.MagicMock() + _dvc_repo.index.stages = [] + _dvc_repo.scm.get_rev.return_value = "current_rev" + _dvc_repo.scm.get_ref.return_value = None + mocker.patch("dvclive.live.get_dvc_repo", return_value=_dvc_repo) + return _dvc_repo + + +@pytest.fixture +def dvc_repo(tmp_dir): # pylint: disable=redefined-outer-name + from dvc.repo import Repo + from scmrepo.git import Git + + Git.init(tmp_dir) + repo = Repo.init(tmp_dir) + repo.scm.add_commit(".", "init") + return repo + + @pytest.fixture(autouse=True) def capture_wrap(): # https://github.com/pytest-dev/pytest/issues/5502#issuecomment-678368525 diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 1a0940e7..307d3962 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -1,4 +1,5 @@ # pylint: disable=unused-argument,protected-access + import pytest from dvc.repo import Repo from PIL import Image @@ -85,24 +86,21 @@ def test_make_dvcyaml_all_plots(tmp_dir): @pytest.mark.parametrize("save", [True, False]) -def test_exp_save_on_end(tmp_dir, mocker, save): - dvc_repo = mocker.MagicMock() - dvc_repo.index.stages = [] - dvc_repo.scm.get_rev.return_value = "current_rev" - dvc_repo.scm.get_ref.return_value = None - with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo): - live = Live(save_dvc_exp=save) - live.end() +def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): + live = Live(save_dvc_exp=save) + live.end() if save: assert live._baseline_rev is not None assert live._exp_name != "dvclive-exp" - dvc_repo.experiments.save.assert_called_with( - name=live._exp_name, include_untracked=[live.dir], force=True + mocked_dvc_repo.experiments.save.assert_called_with( + name=live._exp_name, + include_untracked=[live.dir], + force=True, ) else: assert live._baseline_rev is not None assert live._exp_name == "dvclive-exp" - dvc_repo.experiments.save.assert_not_called() + mocked_dvc_repo.experiments.save.assert_not_called() def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker): @@ -139,14 +137,9 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker): @pytest.mark.parametrize("dvcyaml", [True, False]) -def test_dvcyaml_on_next_step(tmp_dir, mocker, dvcyaml): - dvc_repo = mocker.MagicMock() - dvc_repo.index.stages = [] - dvc_repo.scm.get_rev.return_value = "current_rev" - dvc_repo.scm.get_ref.return_value = None - with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo): - live = Live(dvcyaml=dvcyaml) - live.next_step() +def test_dvcyaml_on_next_step(tmp_dir, dvcyaml, mocked_dvc_repo): + live = Live(dvcyaml=dvcyaml) + live.next_step() if dvcyaml: assert (tmp_dir / live.dvc_file).exists() else: @@ -154,14 +147,9 @@ def test_dvcyaml_on_next_step(tmp_dir, mocker, dvcyaml): @pytest.mark.parametrize("dvcyaml", [True, False]) -def test_dvcyaml_on_end(tmp_dir, mocker, dvcyaml): - dvc_repo = mocker.MagicMock() - dvc_repo.index.stages = [] - dvc_repo.scm.get_rev.return_value = "current_rev" - dvc_repo.scm.get_ref.return_value = None - with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo): - live = Live(dvcyaml=dvcyaml) - live.end() +def test_dvcyaml_on_end(tmp_dir, dvcyaml, mocked_dvc_repo): + live = Live(dvcyaml=dvcyaml) + live.end() if dvcyaml: assert (tmp_dir / live.dvc_file).exists() else: diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py new file mode 100644 index 00000000..624ea33e --- /dev/null +++ b/tests/test_log_artifact.py @@ -0,0 +1,47 @@ +# pylint: disable=unused-argument,protected-access +from dvclive import Live + + +def test_log_artifact(tmp_dir, dvc_repo): + data = tmp_dir / "data" + data.touch() + with Live() as live: + live.log_artifact("data") + assert data.with_suffix(".dvc").exists() + + +def test_log_artifact_on_existing_dvc_file(tmp_dir, dvc_repo): + data = tmp_dir / "data" + data.write_text("foo") + with Live() as live: + live.log_artifact("data") + + prev_content = data.with_suffix(".dvc").read_text() + + with Live() as live: + data.write_text("bar") + live.log_artifact("data") + + assert data.with_suffix(".dvc").read_text() != prev_content + + +def test_log_artifact_twice(tmp_dir, dvc_repo): + data = tmp_dir / "data" + with Live() as live: + for i in range(2): + data.write_text(str(i)) + live.log_artifact("data") + assert data.with_suffix(".dvc").exists() + + +def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo): + stage = mocker.MagicMock() + stage.addressing = "data" + mocked_dvc_repo.add.return_value = [stage] + with Live(save_dvc_exp=True) as live: + live.log_artifact("data") + mocked_dvc_repo.experiments.save.assert_called_with( + name=live._exp_name, + include_untracked=[live.dir, "data", ".gitignore"], + force=True, + )