Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict

import dpath.util
import toml
import yaml
from voluptuous import Any

Expand All @@ -21,6 +22,8 @@ class ParamsDependency(LocalDependency):
PARAM_PARAMS = "params"
PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)}
DEFAULT_PARAMS_FILE = "params.yaml"
PARAMS_FILE_LOADERS = defaultdict(lambda: yaml.safe_load)
PARAMS_FILE_LOADERS.update({".toml": toml.load})

def __init__(self, stage, path, params):
info = {}
Expand Down Expand Up @@ -87,8 +90,10 @@ def read_params(self):

with self.repo.tree.open(self.path_info, "r") as fobj:
try:
config = yaml.safe_load(fobj)
except yaml.YAMLError as exc:
config = self.PARAMS_FILE_LOADERS[
self.path_info.suffix.lower()
](fobj)
except (yaml.YAMLError, toml.TomlDecodeError) as exc:
raise BadParamFileError(
f"Unable to read parameters from '{self}'"
) from exc
Expand Down
7 changes: 5 additions & 2 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import toml
import yaml

from dvc.dependency.param import ParamsDependency
Expand Down Expand Up @@ -34,8 +35,10 @@ def _read_params(repo, configs, rev):

with repo.tree.open(config, "r") as fobj:
try:
res[str(config)] = yaml.safe_load(fobj)
except yaml.YAMLError:
res[str(config)] = ParamsDependency.PARAMS_FILE_LOADERS[
config.suffix.lower()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

](fobj)
except (yaml.YAMLError, toml.TomlDecodeError):
logger.debug(
"failed to read '%s' on '%s'", config, rev, exc_info=True
)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ count=true
[isort]
include_trailing_comma=true
known_first_party=dvc,tests
known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,tqdm,voluptuous,yaml,zc
known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pygtrie,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,toml,tqdm,voluptuous,yaml,zc
line_length=79
force_grid_wrap=0
use_parentheses=True
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def run(self):
"appdirs>=1.4.3",
"PyYAML>=5.1.2,<5.4", # Compatibility with awscli
"ruamel.yaml>=0.16.1",
"toml>=0.10.1",
"funcy>=1.14",
"pathspec>=0.6.0",
"shortuuid>=0.5.0",
Expand Down
10 changes: 10 additions & 0 deletions tests/func/params/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ def test_show(tmp_dir, dvc):
assert dvc.params.show() == {"": {"params.yaml": {"foo": "bar"}}}


def test_show_toml(tmp_dir, dvc):
tmp_dir.gen("params.toml", "[foo]\nbar = 42\nbaz = [1, 2]\n")
dvc.run(
cmd="echo params.toml", params=["params.toml:foo"], single_stage=True
)
assert dvc.params.show() == {
"": {"params.toml": {"foo": {"bar": 42, "baz": [1, 2]}}}
}


def test_show_multiple(tmp_dir, dvc):
tmp_dir.gen("params.yaml", "foo: bar\nbaz: qux\n")
dvc.run(
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/dependency/test_params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import toml
import yaml

from dvc.dependency import ParamsDependency, loadd_from, loads_params
Expand Down Expand Up @@ -99,6 +100,37 @@ def test_read_params_nested(tmp_dir, dvc):
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}


def test_read_params_default_loader(tmp_dir, dvc):
parameters_file = "parameters.foo"
tmp_dir.gen(
parameters_file,
yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}),
)
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}


def test_read_params_wrong_suffix(tmp_dir, dvc):
parameters_file = "parameters.toml"
tmp_dir.gen(
parameters_file,
yaml.dump({"some": {"path": {"foo": ["val1", "val2"]}}}),
)
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
with pytest.raises(BadParamFileError):
dep.read_params()


def test_read_params_toml(tmp_dir, dvc):
parameters_file = "parameters.toml"
tmp_dir.gen(
parameters_file,
toml.dumps({"some": {"path": {"foo": ["val1", "val2"]}}}),
)
dep = ParamsDependency(Stage(dvc), parameters_file, ["some.path.foo"])
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}


def test_save_info_missing_config(dvc):
dep = ParamsDependency(Stage(dvc), None, ["foo"])
with pytest.raises(MissingParamsError):
Expand Down