diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index a9abbf5661..b911ab8fc4 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -62,7 +62,7 @@ def __init__(self, key, new, into): ) -class ParamsFileNotFound(ContextError): +class ParamsLoadError(ContextError): pass @@ -349,7 +349,9 @@ def select( def load_from(cls, tree, path: PathInfo, select_keys=None) -> "Context": file = relpath(path) if not tree.exists(path): - raise ParamsFileNotFound(f"'{file}' does not exist") + raise ParamsLoadError(f"'{file}' does not exist") + if tree.isdir(path): + raise ParamsLoadError(f"'{file}' is a directory") _, ext = os.path.splitext(file) loader = LOADERS[ext] @@ -357,7 +359,7 @@ def load_from(cls, tree, path: PathInfo, select_keys=None) -> "Context": data = loader(path, tree=tree) if not isinstance(data, Mapping): typ = type(data).__name__ - raise ContextError( + raise ParamsLoadError( f"expected a dictionary, got '{typ}' in file '{file}'" ) @@ -367,7 +369,7 @@ def load_from(cls, tree, path: PathInfo, select_keys=None) -> "Context": data = {key: data[key] for key in select_keys} except KeyError as exc: key, *_ = exc.args - raise ContextError( + raise ParamsLoadError( f"could not find '{key}' in '{file}'" ) from exc diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 2cab30d18f..eae4dc056e 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -1,5 +1,6 @@ from dataclasses import asdict from math import pi +from unittest.mock import mock_open import pytest @@ -10,13 +11,13 @@ CtxList, KeyNotInContext, MergeError, - ParamsFileNotFound, + ParamsLoadError, Value, recurse_not_a_node, ) from dvc.tree.local import LocalTree from dvc.utils import relpath -from dvc.utils.serialize import dump_yaml +from dvc.utils.serialize import dump_yaml, dumps_yaml def test_context(): @@ -216,17 +217,13 @@ def test_overwrite_with_setitem(): def test_load_from(mocker): - def _yaml_load(*args, **kwargs): - return {"x": {"y": {"z": 5}, "lst": [1, 2, 3]}, "foo": "foo"} - - mocker.patch("dvc.parsing.context.LOADERS", {".yaml": _yaml_load}) - - class tree: - def exists(self, _): - return True - + d = {"x": {"y": {"z": 5}, "lst": [1, 2, 3]}, "foo": "foo"} + tree = mocker.Mock( + open=mock_open(read_data=dumps_yaml(d)), + **{"exists.return_value": True, "isdir.return_value": False}, + ) file = "params.yaml" - c = Context.load_from(tree(), file) + c = Context.load_from(tree, file) assert asdict(c["x"].meta) == { "source": file, @@ -430,7 +427,18 @@ def test_resolve_resolves_boolean_value(): assert context.resolve_str("--flag ${disabled}") == "--flag false" -def test_merge_from_raises_if_file_not_exist(tmp_dir, dvc): - context = Context(foo="bar") - with pytest.raises(ParamsFileNotFound): - context.merge_from(dvc.tree, DEFAULT_PARAMS_FILE, tmp_dir) +def test_load_from_raises_if_file_not_exist(tmp_dir, dvc): + with pytest.raises(ParamsLoadError) as exc_info: + Context.load_from(dvc.tree, tmp_dir / DEFAULT_PARAMS_FILE) + + assert str(exc_info.value) == "'params.yaml' does not exist" + + +def test_load_from_raises_if_file_is_directory(tmp_dir, dvc): + data_dir = tmp_dir / "data" + data_dir.mkdir() + + with pytest.raises(ParamsLoadError) as exc_info: + Context.load_from(dvc.tree, data_dir) + + assert str(exc_info.value) == "'data' is a directory"