diff --git a/dvclive/env.py b/dvclive/env.py index 4ec2b06d..6daa2932 100644 --- a/dvclive/env.py +++ b/dvclive/env.py @@ -1,3 +1,4 @@ +DVCLIVE_OPEN = "DVCLIVE_OPEN" DVCLIVE_PATH = "DVCLIVE_PATH" DVCLIVE_SUMMARY = "DVCLIVE_SUMMARY" DVCLIVE_HTML = "DVCLIVE_HTML" diff --git a/dvclive/live.py b/dvclive/live.py index 67c6de40..c4f4cd16 100644 --- a/dvclive/live.py +++ b/dvclive/live.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Union +from . import env from .data import DATA_TYPES, PLOTS, Image, Scalar from .dvc import make_checkpoint from .error import ( @@ -15,7 +16,7 @@ InvalidPlotTypeError, ) from .report import html_report -from .utils import nested_update, open_file_in_browser +from .utils import env2bool, nested_update, open_file_in_browser logger = logging.getLogger(__name__) @@ -28,14 +29,12 @@ def __init__( path: Optional[str] = None, resume: bool = False, report: Optional[str] = "html", - auto_open: bool = False, ): self._path: Optional[str] = path self._resume: bool = resume self._report: str = report self._checkpoint: bool = False - self._auto_open: bool = auto_open self.html_path = None self.init_from_env() @@ -77,23 +76,19 @@ def _init_paths(self): os.makedirs(self.dir, exist_ok=True) def init_from_env(self) -> None: - from . import env + if os.getenv(env.DVCLIVE_PATH): - if env.DVCLIVE_PATH in os.environ: - - if self.dir and self.dir != os.environ[env.DVCLIVE_PATH]: + if self.dir and self.dir != os.getenv(env.DVCLIVE_PATH): raise ConfigMismatchError(self) env_config = { - "_path": os.environ.get(env.DVCLIVE_PATH), - "_checkpoint": bool( - int(os.environ.get(env.DVC_CHECKPOINT, "0")) - ), - "_resume": bool(int(os.environ.get(env.DVCLIVE_RESUME, "0"))), + "_path": os.getenv(env.DVCLIVE_PATH), + "_checkpoint": env2bool(env.DVC_CHECKPOINT), + "_resume": env2bool(env.DVCLIVE_RESUME), } # Keeping backward compatibility with `live` section - if not bool(int(os.environ.get(env.DVCLIVE_HTML, "0"))): + if not env2bool(env.DVCLIVE_HTML, "0"): env_config["_report"] = None else: path = str(env_config["_path"]) @@ -196,9 +191,8 @@ def make_summary(self): def make_report(self): if self._report == "html": html_report(self.dir, self.summary_path, self.html_path) - if self._auto_open: + if env2bool(env.DVCLIVE_OPEN): open_file_in_browser(self.html_path) - self._auto_open = False def read_step(self): if Path(self.summary_path).exists(): diff --git a/dvclive/utils.py b/dvclive/utils.py index 9c881c41..25f0bb64 100644 --- a/dvclive/utils.py +++ b/dvclive/utils.py @@ -1,5 +1,7 @@ import base64 import csv +import os +import re import webbrowser from collections.abc import Mapping from pathlib import Path @@ -45,8 +47,29 @@ def to_base64_url(image_file): return f"data:image;base64,{base64_str}" +def run_once(f): + def wrapper(*args, **kwargs): + if not wrapper.has_run: + wrapper.has_run = True + return f(*args, **kwargs) + + wrapper.has_run = False + return wrapper + + +@run_once def open_file_in_browser(file) -> bool: path = Path(file) url = path if "Microsoft" in uname().release else path.resolve().as_uri() return webbrowser.open(url) + + +def env2bool(var, undefined=False): + """ + undefined: return value if env var is unset + """ + var = os.getenv(var, None) + if var is None: + return undefined + return bool(re.search("1|y|yes|true", var, flags=re.I)) diff --git a/tests/test_report.py b/tests/test_report.py index 664146f5..167a336b 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -1,12 +1,14 @@ # pylint: disable=unused-argument import os +import pytest from PIL import Image from dvclive import Live from dvclive.data import Image as LiveImage from dvclive.data import Scalar from dvclive.data.plot import ConfusionMatrix, Plot +from dvclive.env import DVCLIVE_OPEN from dvclive.report import ( get_image_renderers, get_plot_renderers, @@ -58,7 +60,8 @@ def test_get_renderers(tmp_dir, mocker): assert plot_renderers[0].properties == ConfusionMatrix.get_properties() -def test_make_report_open(tmp_dir, mocker): +@pytest.mark.vscode +def test_make_report_open(tmp_dir, mocker, monkeypatch): mocked_open = mocker.patch("webbrowser.open") live = Live() live.log_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) @@ -67,15 +70,15 @@ def test_make_report_open(tmp_dir, mocker): assert not mocked_open.called - mocked_open = mocker.patch("webbrowser.open") live = Live(report=None) live.log("foo", 1) live.next_step() assert not mocked_open.called - mocked_open = mocker.patch("webbrowser.open") - live = Live(auto_open=True) + monkeypatch.setenv(DVCLIVE_OPEN, True) + + live = Live() live.log_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) live.make_report()