Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "INFO").upper())

ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]]
StrPath = Union[str, Path]


class Live:
Expand All @@ -57,6 +58,7 @@ def __init__(
self._images: Dict[str, Any] = {}
self._params: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}
self._outs: Set[StrPath] = set()
self._inside_with = False
self._dvcyaml = dvcyaml

Expand All @@ -76,6 +78,7 @@ def __init__(
self._experiment_rev: Optional[str] = None
self._inside_dvc_exp: bool = False
self._dvc_repo = None
self._include_untracked: List[str] = []
self._init_dvc()

self._latest_studio_step = self.step if resume else -1
Expand Down Expand Up @@ -131,6 +134,7 @@ def _init_dvc(self):
if self._save_dvc_exp:
self._exp_name = get_random_exp_name(self._dvc_repo.scm, self._baseline_rev)
mark_dvclive_only_started()
self._include_untracked.append(self.dir)

def _init_studio(self):
if not os.getenv(STUDIO_TOKEN, None):
Expand Down Expand Up @@ -300,14 +304,31 @@ def log_params(self, params: Dict[str, ParamLike]):
self._dump_params()
logger.debug(f"Logged {params} parameters to {self.params_file}")

def log_param(
self,
name: str,
val: ParamLike,
):
def log_param(self, name: str, val: ParamLike):
"""Saves the given parameter value to yaml"""
self.log_params({name: val})

def log_artifact(self, path: StrPath):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
raise InvalidDataTypeError(path, type(path))

if self._dvc_repo is not None:
try:
stage = self._dvc_repo.add(path)
except Exception as e: # pylint: disable=broad-except
logger.warning(f"Failed to dvc add {path}: {e}")
return

self._outs.add(path)
dvc_file = stage[0].addressing

if self._save_dvc_exp:
self._include_untracked.append(dvc_file)
self._include_untracked.append(
str(Path(dvc_file).parent / ".gitignore")
)

def make_summary(self, update_step: bool = True):
if self._step is not None and update_step:
self.summary["step"] = self.step
Expand Down Expand Up @@ -378,7 +399,9 @@ def end(self):

try:
self._experiment_rev = self._dvc_repo.experiments.save(
name=self._exp_name, include_untracked=[self.dir], force=True
name=self._exp_name,
include_untracked=self._include_untracked,
force=True,
)
except DvcException as e:
logger.warning(f"Failed to save experiment:\n{e}")
Expand Down
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,27 @@ def tmp_dir(tmp_path, monkeypatch):
yield tmp_path


@pytest.fixture
def mocked_dvc_repo(mocker):
_dvc_repo = mocker.MagicMock()
_dvc_repo.index.stages = []
_dvc_repo.scm.get_rev.return_value = "current_rev"
_dvc_repo.scm.get_ref.return_value = None
mocker.patch("dvclive.live.get_dvc_repo", return_value=_dvc_repo)
return _dvc_repo


@pytest.fixture
def dvc_repo(tmp_dir): # pylint: disable=redefined-outer-name
from dvc.repo import Repo
from scmrepo.git import Git

Git.init(tmp_dir)
repo = Repo.init(tmp_dir)
repo.scm.add_commit(".", "init")
return repo


@pytest.fixture(autouse=True)
def capture_wrap():
# https://github.com/pytest-dev/pytest/issues/5502#issuecomment-678368525
Expand Down
42 changes: 15 additions & 27 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint: disable=unused-argument,protected-access

import pytest
from dvc.repo import Repo
from PIL import Image
Expand Down Expand Up @@ -85,24 +86,21 @@ def test_make_dvcyaml_all_plots(tmp_dir):


@pytest.mark.parametrize("save", [True, False])
def test_exp_save_on_end(tmp_dir, mocker, save):
dvc_repo = mocker.MagicMock()
dvc_repo.index.stages = []
dvc_repo.scm.get_rev.return_value = "current_rev"
dvc_repo.scm.get_ref.return_value = None
with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo):
live = Live(save_dvc_exp=save)
live.end()
def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo):
live = Live(save_dvc_exp=save)
live.end()
if save:
assert live._baseline_rev is not None
assert live._exp_name != "dvclive-exp"
dvc_repo.experiments.save.assert_called_with(
name=live._exp_name, include_untracked=[live.dir], force=True
mocked_dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir],
force=True,
)
else:
assert live._baseline_rev is not None
assert live._exp_name == "dvclive-exp"
dvc_repo.experiments.save.assert_not_called()
mocked_dvc_repo.experiments.save.assert_not_called()


def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
Expand Down Expand Up @@ -139,29 +137,19 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker):


@pytest.mark.parametrize("dvcyaml", [True, False])
def test_dvcyaml_on_next_step(tmp_dir, mocker, dvcyaml):
dvc_repo = mocker.MagicMock()
dvc_repo.index.stages = []
dvc_repo.scm.get_rev.return_value = "current_rev"
dvc_repo.scm.get_ref.return_value = None
with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo):
live = Live(dvcyaml=dvcyaml)
live.next_step()
def test_dvcyaml_on_next_step(tmp_dir, dvcyaml, mocked_dvc_repo):
live = Live(dvcyaml=dvcyaml)
live.next_step()
if dvcyaml:
assert (tmp_dir / live.dvc_file).exists()
else:
assert not (tmp_dir / live.dvc_file).exists()


@pytest.mark.parametrize("dvcyaml", [True, False])
def test_dvcyaml_on_end(tmp_dir, mocker, dvcyaml):
dvc_repo = mocker.MagicMock()
dvc_repo.index.stages = []
dvc_repo.scm.get_rev.return_value = "current_rev"
dvc_repo.scm.get_ref.return_value = None
with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo):
live = Live(dvcyaml=dvcyaml)
live.end()
def test_dvcyaml_on_end(tmp_dir, dvcyaml, mocked_dvc_repo):
live = Live(dvcyaml=dvcyaml)
live.end()
if dvcyaml:
assert (tmp_dir / live.dvc_file).exists()
else:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# pylint: disable=unused-argument,protected-access
from dvclive import Live


def test_log_artifact(tmp_dir, dvc_repo):
data = tmp_dir / "data"
data.touch()
with Live() as live:
live.log_artifact("data")
assert data.with_suffix(".dvc").exists()


def test_log_artifact_on_existing_dvc_file(tmp_dir, dvc_repo):
data = tmp_dir / "data"
data.write_text("foo")
with Live() as live:
live.log_artifact("data")

prev_content = data.with_suffix(".dvc").read_text()

with Live() as live:
data.write_text("bar")
live.log_artifact("data")

assert data.with_suffix(".dvc").read_text() != prev_content


def test_log_artifact_twice(tmp_dir, dvc_repo):
data = tmp_dir / "data"
with Live() as live:
for i in range(2):
data.write_text(str(i))
live.log_artifact("data")
assert data.with_suffix(".dvc").exists()


def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo):
stage = mocker.MagicMock()
stage.addressing = "data"
mocked_dvc_repo.add.return_value = [stage]
with Live(save_dvc_exp=True) as live:
live.log_artifact("data")
mocked_dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, "data", ".gitignore"],
force=True,
)