From 0f452f06d9a5cc32f8d28f476eace4164ddd094b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Wed, 19 Aug 2020 19:35:14 +0545 Subject: [PATCH] Make serializing utils use tree, misc changes * Also fixes plots not getting dumped on flow style * Make serialization utils aware of trees * Directly load yamls/tomls using utils, rather than using parser --- dvc/dependency/param.py | 17 ++++++++--------- dvc/dvcfile.py | 11 ++++++++--- dvc/repo/metrics/show.py | 6 ++---- dvc/repo/params/show.py | 19 +++++++++---------- dvc/repo/plots/data.py | 4 ++-- dvc/utils/serialize/__init__.py | 4 ++-- dvc/utils/serialize/_common.py | 12 ++++++++++++ dvc/utils/serialize/_toml.py | 15 +++++++++++---- dvc/utils/serialize/_yaml.py | 23 +++++++++++++---------- tests/func/test_repro_multistage.py | 15 +++++++-------- 10 files changed, 74 insertions(+), 52 deletions(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 17f7ab2b19..d08610bf50 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -6,7 +6,7 @@ from dvc.dependency.local import LocalDependency from dvc.exceptions import DvcException -from dvc.utils.serialize import PARSERS, ParseError +from dvc.utils.serialize import LOADERS, ParseError class MissingParamsError(DvcException): @@ -86,14 +86,13 @@ def read_params(self): return {} suffix = self.path_info.suffix.lower() - parser = PARSERS[suffix] - with self.repo.tree.open(self.path_info, "r") as fobj: - try: - config = parser(fobj.read(), self.path_info) - except ParseError as exc: - raise BadParamFileError( - f"Unable to read parameters from '{self}'" - ) from exc + loader = LOADERS[suffix] + try: + config = loader(self.path_info, tree=self.repo.tree) + except ParseError as exc: + raise BadParamFileError( + f"Unable to read parameters from '{self}'" + ) from exc ret = {} for param in self.params: diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index f523192544..4c1af7f739 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -16,7 +16,12 @@ from dvc.stage.loader import SingleStageLoader, StageLoader from dvc.utils import relpath from dvc.utils.collections import apply_diff -from dvc.utils.serialize import dump_yaml, parse_yaml, parse_yaml_for_update +from dvc.utils.serialize import ( + dump_yaml, + load_yaml, + parse_yaml, + parse_yaml_for_update, +) logger = logging.getLogger(__name__) @@ -264,8 +269,8 @@ class Lockfile(FileMixin): def load(self): if not self.exists(): return {} - with self.repo.tree.open(self.path) as fd: - data = parse_yaml(fd.read(), self.path) + + data = load_yaml(self.path, tree=self.repo.tree) try: self.validate(data, fname=self.relpath) except StageFileFormatError: diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index 1921740731..0913a59719 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -5,7 +5,7 @@ from dvc.path_info import PathInfo from dvc.repo import locked from dvc.tree.repo import RepoTree -from dvc.utils.serialize import YAMLFileCorruptedError, parse_yaml +from dvc.utils.serialize import YAMLFileCorruptedError, load_yaml logger = logging.getLogger(__name__) @@ -69,9 +69,7 @@ def _read_metrics(repo, metrics, rev): continue try: - with tree.open(metric, "r") as fobj: - # NOTE this also supports JSON - val = parse_yaml(fobj.read(), metric) + val = load_yaml(metric, tree=tree) except (FileNotFoundError, YAMLFileCorruptedError): logger.debug( "failed to read '%s' on '%s'", metric, rev, exc_info=True diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index a7803ff6ff..7b2f59457c 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -4,7 +4,7 @@ from dvc.exceptions import DvcException from dvc.path_info import PathInfo from dvc.repo import locked -from dvc.utils.serialize import PARSERS, ParseError +from dvc.utils.serialize import LOADERS, ParseError logger = logging.getLogger(__name__) @@ -32,15 +32,14 @@ def _read_params(repo, configs, rev): continue suffix = config.suffix.lower() - parser = PARSERS[suffix] - with repo.tree.open(config, "r") as fobj: - try: - res[str(config)] = parser(fobj.read(), config) - except ParseError: - logger.debug( - "failed to read '%s' on '%s'", config, rev, exc_info=True - ) - continue + loader = LOADERS[suffix] + try: + res[str(config)] = loader(config, tree=repo.tree) + except ParseError: + logger.debug( + "failed to read '%s' on '%s'", config, rev, exc_info=True + ) + continue return res diff --git a/dvc/repo/plots/data.py b/dvc/repo/plots/data.py index 02caca70f4..a8cbc4ef5c 100644 --- a/dvc/repo/plots/data.py +++ b/dvc/repo/plots/data.py @@ -6,9 +6,9 @@ from copy import copy from funcy import first -from ruamel.yaml import YAML from dvc.exceptions import DvcException +from dvc.utils.serialize import loads_yaml class PlotMetricTypeError(DvcException): @@ -207,7 +207,7 @@ def raw(self, header=True, **kwargs): # pylint: disable=arguments-differ class YAMLPlotData(PlotData): def raw(self, **kwargs): - return YAML().load(self.content) + return loads_yaml(self.content, typ="rt") def _processors(self): parent_processors = super()._processors() diff --git a/dvc/utils/serialize/__init__.py b/dvc/utils/serialize/__init__.py index ab658a5c53..33b85159e3 100644 --- a/dvc/utils/serialize/__init__.py +++ b/dvc/utils/serialize/__init__.py @@ -4,5 +4,5 @@ from ._toml import * # noqa, pylint: disable=wildcard-import from ._yaml import * # noqa, pylint: disable=wildcard-import -PARSERS = defaultdict(lambda: parse_yaml) # noqa: F405 -PARSERS.update({".toml": parse_toml}) # noqa: F405 +LOADERS = defaultdict(lambda: load_yaml) # noqa: F405 +LOADERS.update({".toml": load_toml}) # noqa: F405 diff --git a/dvc/utils/serialize/_common.py b/dvc/utils/serialize/_common.py index 16c5b5e1a1..10dbab3bcf 100644 --- a/dvc/utils/serialize/_common.py +++ b/dvc/utils/serialize/_common.py @@ -10,3 +10,15 @@ class ParseError(DvcException): def __init__(self, path, message): path = relpath(path) super().__init__(f"unable to read: '{path}', {message}") + + +def _load_data(path, parser, tree=None): + open_fn = tree.open if tree else open + with open_fn(path, encoding="utf-8") as fd: + return parser(fd.read(), path) + + +def _dump_data(path, data, dumper, tree=None): + open_fn = tree.open if tree else open + with open_fn(path, "w+", encoding="utf-8") as fd: + dumper(data, fd) diff --git a/dvc/utils/serialize/_toml.py b/dvc/utils/serialize/_toml.py index 7c002ba7df..916c779d45 100644 --- a/dvc/utils/serialize/_toml.py +++ b/dvc/utils/serialize/_toml.py @@ -1,7 +1,7 @@ import toml from funcy import reraise -from ._common import ParseError +from ._common import ParseError, _dump_data, _load_data class TOMLFileCorruptedError(ParseError): @@ -9,6 +9,10 @@ def __init__(self, path): super().__init__(path, "TOML file structure is corrupted") +def load_toml(path, tree=None): + return _load_data(path, parser=parse_toml, tree=tree) + + def parse_toml(text, path, decoder=None): with reraise(toml.TomlDecodeError, TOMLFileCorruptedError(path)): return toml.loads(text, decoder=decoder) @@ -25,6 +29,9 @@ def parse_toml_for_update(text, path): return parse_toml(text, path, decoder=decoder) -def dump_toml(path, data): - with open(path, "w+", encoding="utf-8") as fobj: - toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder()) +def _dump(data, stream): + return toml.dump(data, stream, encoder=toml.TomlPreserveCommentEncoder()) + + +def dump_toml(path, data, tree=None): + return _dump_data(path, data, dumper=_dump, tree=tree) diff --git a/dvc/utils/serialize/_yaml.py b/dvc/utils/serialize/_yaml.py index f60f7c64a2..ec04ac30b6 100644 --- a/dvc/utils/serialize/_yaml.py +++ b/dvc/utils/serialize/_yaml.py @@ -5,7 +5,7 @@ from ruamel.yaml import YAML from ruamel.yaml.error import YAMLError -from ._common import ParseError +from ._common import ParseError, _dump_data, _load_data class YAMLFileCorruptedError(ParseError): @@ -13,9 +13,8 @@ def __init__(self, path): super().__init__(path, "YAML file structure is corrupted") -def load_yaml(path): - with open(path, encoding="utf-8") as fd: - return parse_yaml(fd.read(), path) +def load_yaml(path, tree=None): + return _load_data(path, parser=parse_yaml, tree=tree) def parse_yaml(text, path, typ="safe"): @@ -46,17 +45,21 @@ def _get_yaml(): return yaml -def dump_yaml(path, data): +def _dump(data, stream): yaml = _get_yaml() - with open(path, "w+", encoding="utf-8") as fd: - yaml.dump(data, fd) + return yaml.dump(data, stream) -def loads_yaml(s): - return YAML(typ="safe").load(s) +def dump_yaml(path, data, tree=None): + return _dump_data(path, data, dumper=_dump, tree=tree) + + +def loads_yaml(s, typ="safe"): + return YAML(typ=typ).load(s) def dumps_yaml(d): stream = io.StringIO() - YAML().dump(d, stream) + yaml = _get_yaml() + yaml.dump(d, stream) return stream.getvalue() diff --git a/tests/func/test_repro_multistage.py b/tests/func/test_repro_multistage.py index 37ba758d61..4916d66c3f 100644 --- a/tests/func/test_repro_multistage.py +++ b/tests/func/test_repro_multistage.py @@ -9,7 +9,7 @@ from dvc.exceptions import CyclicGraphError from dvc.main import main from dvc.stage import PipelineStage -from dvc.utils.serialize import dump_yaml, parse_yaml +from dvc.utils.serialize import dump_yaml, load_yaml from tests.func import test_repro COPY_SCRIPT_FORMAT = dedent( @@ -457,13 +457,12 @@ def test_cyclic_graph_error(tmp_dir, dvc, run_copy): run_copy("bar", "baz", name="copy-bar-baz") run_copy("baz", "foobar", name="copy-baz-foobar") - with open(PIPELINE_FILE) as f: - data = parse_yaml(f.read(), PIPELINE_FILE) - data["stages"]["copy-baz-foo"] = { - "cmd": "echo baz > foo", - "deps": ["baz"], - "outs": ["foo"], - } + data = load_yaml(PIPELINE_FILE) + data["stages"]["copy-baz-foo"] = { + "cmd": "echo baz > foo", + "deps": ["baz"], + "outs": ["foo"], + } dump_yaml(PIPELINE_FILE, data) with pytest.raises(CyclicGraphError): dvc.reproduce(":copy-baz-foo")