Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dvc.dependency.local import LocalDependency
from dvc.exceptions import DvcException
from dvc.utils.serialize import PARSERS, ParseError
from dvc.utils.serialize import LOADERS, ParseError


class MissingParamsError(DvcException):
Expand Down Expand Up @@ -86,14 +86,13 @@ def read_params(self):
return {}

suffix = self.path_info.suffix.lower()
parser = PARSERS[suffix]
with self.repo.tree.open(self.path_info, "r") as fobj:
try:
config = parser(fobj.read(), self.path_info)
except ParseError as exc:
raise BadParamFileError(
f"Unable to read parameters from '{self}'"
) from exc
loader = LOADERS[suffix]
try:
config = loader(self.path_info, tree=self.repo.tree)
except ParseError as exc:
raise BadParamFileError(
f"Unable to read parameters from '{self}'"
) from exc

ret = {}
for param in self.params:
Expand Down
11 changes: 8 additions & 3 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from dvc.stage.loader import SingleStageLoader, StageLoader
from dvc.utils import relpath
from dvc.utils.collections import apply_diff
from dvc.utils.serialize import dump_yaml, parse_yaml, parse_yaml_for_update
from dvc.utils.serialize import (
dump_yaml,
load_yaml,
parse_yaml,
parse_yaml_for_update,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -264,8 +269,8 @@ class Lockfile(FileMixin):
def load(self):
if not self.exists():
return {}
with self.repo.tree.open(self.path) as fd:
data = parse_yaml(fd.read(), self.path)

data = load_yaml(self.path, tree=self.repo.tree)
try:
self.validate(data, fname=self.relpath)
except StageFileFormatError:
Expand Down
6 changes: 2 additions & 4 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dvc.path_info import PathInfo
from dvc.repo import locked
from dvc.tree.repo import RepoTree
from dvc.utils.serialize import YAMLFileCorruptedError, parse_yaml
from dvc.utils.serialize import YAMLFileCorruptedError, load_yaml

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,9 +69,7 @@ def _read_metrics(repo, metrics, rev):
continue

try:
with tree.open(metric, "r") as fobj:
# NOTE this also supports JSON
val = parse_yaml(fobj.read(), metric)
val = load_yaml(metric, tree=tree)
except (FileNotFoundError, YAMLFileCorruptedError):
logger.debug(
"failed to read '%s' on '%s'", metric, rev, exc_info=True
Expand Down
19 changes: 9 additions & 10 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
from dvc.repo import locked
from dvc.utils.serialize import PARSERS, ParseError
from dvc.utils.serialize import LOADERS, ParseError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -32,15 +32,14 @@ def _read_params(repo, configs, rev):
continue

suffix = config.suffix.lower()
parser = PARSERS[suffix]
with repo.tree.open(config, "r") as fobj:
try:
res[str(config)] = parser(fobj.read(), config)
except ParseError:
logger.debug(
"failed to read '%s' on '%s'", config, rev, exc_info=True
)
continue
loader = LOADERS[suffix]
try:
res[str(config)] = loader(config, tree=repo.tree)
except ParseError:
logger.debug(
"failed to read '%s' on '%s'", config, rev, exc_info=True
)
continue

return res

Expand Down
4 changes: 2 additions & 2 deletions dvc/repo/plots/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from copy import copy

from funcy import first
from ruamel.yaml import YAML

from dvc.exceptions import DvcException
from dvc.utils.serialize import loads_yaml


class PlotMetricTypeError(DvcException):
Expand Down Expand Up @@ -207,7 +207,7 @@ def raw(self, header=True, **kwargs): # pylint: disable=arguments-differ

class YAMLPlotData(PlotData):
def raw(self, **kwargs):
return YAML().load(self.content)
return loads_yaml(self.content, typ="rt")

def _processors(self):
parent_processors = super()._processors()
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
from ._toml import * # noqa, pylint: disable=wildcard-import
from ._yaml import * # noqa, pylint: disable=wildcard-import

PARSERS = defaultdict(lambda: parse_yaml) # noqa: F405
PARSERS.update({".toml": parse_toml}) # noqa: F405
LOADERS = defaultdict(lambda: load_yaml) # noqa: F405
LOADERS.update({".toml": load_toml}) # noqa: F405
12 changes: 12 additions & 0 deletions dvc/utils/serialize/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,15 @@ class ParseError(DvcException):
def __init__(self, path, message):
path = relpath(path)
super().__init__(f"unable to read: '{path}', {message}")


def _load_data(path, parser, tree=None):
open_fn = tree.open if tree else open
with open_fn(path, encoding="utf-8") as fd:
return parser(fd.read(), path)


def _dump_data(path, data, dumper, tree=None):
open_fn = tree.open if tree else open
with open_fn(path, "w+", encoding="utf-8") as fd:
dumper(data, fd)
15 changes: 11 additions & 4 deletions dvc/utils/serialize/_toml.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import toml
from funcy import reraise

from ._common import ParseError
from ._common import ParseError, _dump_data, _load_data


class TOMLFileCorruptedError(ParseError):
def __init__(self, path):
super().__init__(path, "TOML file structure is corrupted")


def load_toml(path, tree=None):
return _load_data(path, parser=parse_toml, tree=tree)


def parse_toml(text, path, decoder=None):
with reraise(toml.TomlDecodeError, TOMLFileCorruptedError(path)):
return toml.loads(text, decoder=decoder)
Expand All @@ -25,6 +29,9 @@ def parse_toml_for_update(text, path):
return parse_toml(text, path, decoder=decoder)


def dump_toml(path, data):
with open(path, "w+", encoding="utf-8") as fobj:
toml.dump(data, fobj, encoder=toml.TomlPreserveCommentEncoder())
def _dump(data, stream):
return toml.dump(data, stream, encoder=toml.TomlPreserveCommentEncoder())


def dump_toml(path, data, tree=None):
return _dump_data(path, data, dumper=_dump, tree=tree)
23 changes: 13 additions & 10 deletions dvc/utils/serialize/_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
from ruamel.yaml import YAML
from ruamel.yaml.error import YAMLError

from ._common import ParseError
from ._common import ParseError, _dump_data, _load_data


class YAMLFileCorruptedError(ParseError):
def __init__(self, path):
super().__init__(path, "YAML file structure is corrupted")


def load_yaml(path):
with open(path, encoding="utf-8") as fd:
return parse_yaml(fd.read(), path)
def load_yaml(path, tree=None):
return _load_data(path, parser=parse_yaml, tree=tree)


def parse_yaml(text, path, typ="safe"):
Expand Down Expand Up @@ -46,17 +45,21 @@ def _get_yaml():
return yaml


def dump_yaml(path, data):
def _dump(data, stream):
yaml = _get_yaml()
with open(path, "w+", encoding="utf-8") as fd:
yaml.dump(data, fd)
return yaml.dump(data, stream)


def loads_yaml(s):
return YAML(typ="safe").load(s)
def dump_yaml(path, data, tree=None):
return _dump_data(path, data, dumper=_dump, tree=tree)


def loads_yaml(s, typ="safe"):
return YAML(typ=typ).load(s)


def dumps_yaml(d):
stream = io.StringIO()
YAML().dump(d, stream)
yaml = _get_yaml()
yaml.dump(d, stream)
return stream.getvalue()
15 changes: 7 additions & 8 deletions tests/func/test_repro_multistage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dvc.exceptions import CyclicGraphError
from dvc.main import main
from dvc.stage import PipelineStage
from dvc.utils.serialize import dump_yaml, parse_yaml
from dvc.utils.serialize import dump_yaml, load_yaml
from tests.func import test_repro

COPY_SCRIPT_FORMAT = dedent(
Expand Down Expand Up @@ -457,13 +457,12 @@ def test_cyclic_graph_error(tmp_dir, dvc, run_copy):
run_copy("bar", "baz", name="copy-bar-baz")
run_copy("baz", "foobar", name="copy-baz-foobar")

with open(PIPELINE_FILE) as f:
data = parse_yaml(f.read(), PIPELINE_FILE)
data["stages"]["copy-baz-foo"] = {
"cmd": "echo baz > foo",
"deps": ["baz"],
"outs": ["foo"],
}
data = load_yaml(PIPELINE_FILE)
data["stages"]["copy-baz-foo"] = {
"cmd": "echo baz > foo",
"deps": ["baz"],
"outs": ["foo"],
}
dump_yaml(PIPELINE_FILE, data)
with pytest.raises(CyclicGraphError):
dvc.reproduce(":copy-baz-foo")
Expand Down