diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 1fa2073e24..10a2f787a6 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -87,26 +87,29 @@ def loads_from(stage, s_list, erepo=None): return ret -def _parse_params(path_params): - path, _, params_str = path_params.rpartition(":") - params = params_str.split(",") - return path, params +def _merge_params(s_list): + d = defaultdict(list) + default_file = ParamsDependency.DEFAULT_PARAMS_FILE + for key in s_list: + if isinstance(key, str): + d[default_file].append(key) + continue + if not isinstance(key, dict): + msg = "Only list of str/dict is supported. Got: " + msg += f"'{type(key).__name__}'." + raise ValueError(msg) + + for k, params in key.items(): + if not isinstance(params, list): + msg = "Expected list of params for custom params file " + msg += f"'{k}', got '{type(params).__name__}'." + raise ValueError(msg) + d[k].extend(params) + return d def loads_params(stage, s_list): - # Creates an object for each unique file that is referenced in the list - params_by_path = defaultdict(list) - for s in s_list: - path, params = _parse_params(s) - params_by_path[path].extend(params) - - d_list = [] - for path, params in params_by_path.items(): - d_list.append( - { - BaseOutput.PARAM_PATH: path, - ParamsDependency.PARAM_PARAMS: params, - } - ) - - return loadd_from(stage, d_list) + d = _merge_params(s_list) + return [ + ParamsDependency(stage, path, params) for path, params in d.items() + ] diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 9add46457d..4df6087f52 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -40,7 +40,7 @@ def __init__(self, stage, path, params): info=info, ) - def _dyn_load(self, values=None): + def fill_values(self, values=None): """Load params values dynamically.""" if not values: return diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index 492b0e4862..74a550a8f5 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -1,5 +1,7 @@ +from collections import defaultdict from urllib.parse import urlparse +from funcy import collecting, project from voluptuous import And, Any, Coerce, Length, Lower, Required, SetTo from dvc.output.base import BaseOutput @@ -58,7 +60,9 @@ SCHEMA[BaseOutput.PARAM_PERSIST] = bool -def _get(stage, p, info, cache, metric, plot=False, persist=False): +def _get( + stage, p, info=None, cache=True, metric=False, plot=False, persist=False +): parsed = urlparse(p) if parsed.scheme == "remote": @@ -135,3 +139,45 @@ def loads_from( ) for s in s_list ] + + +def _split_dict(d, keys): + return project(d, keys), project(d, d.keys() - keys) + + +def _merge_data(s_list): + d = defaultdict(dict) + for key in s_list: + if isinstance(key, str): + d[key].update({}) + continue + if not isinstance(key, dict): + raise ValueError(f"'{type(key).__name__}' not supported.") + + for k, flags in key.items(): + if not isinstance(flags, dict): + raise ValueError( + f"Expected dict for '{k}', got: '{type(flags).__name__}'" + ) + d[k].update(flags) + return d + + +@collecting +def load_from_pipeline(stage, s_list, typ="outs"): + if typ not in (stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS): + raise ValueError(f"'{typ}' key is not allowed for pipeline files.") + + metric = typ == stage.PARAM_METRICS + plot = typ == stage.PARAM_PLOTS + + d = _merge_data(s_list) + + for path, flags in d.items(): + plt_d = {} + if plot: + from dvc.schema import PLOT_PROPS + + plt_d, flags = _split_dict(flags, keys=PLOT_PROPS.keys()) + extra = project(flags, ["cache", "persist"]) + yield _get(stage, path, {}, plot=plt_d or plot, metric=metric, **extra) diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 313d00feaa..36c52726ca 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -1,6 +1,6 @@ import os -from funcy import concat, first +from funcy import concat, first, lfilter from dvc.exceptions import InvalidArgumentError from dvc.stage.exceptions import ( @@ -20,6 +20,19 @@ def is_valid_name(name: str): return not INVALID_STAGENAME_CHARS & set(name) +def parse_params(path_params): + ret = [] + for path_param in path_params: + path, _, params_str = path_param.rpartition(":") + # remove empty strings from params, on condition such as `-p "file1:"` + params = lfilter(bool, params_str.split(",")) + if not path: + ret.extend(params) + else: + ret.append({path: params}) + return ret + + def _get_file_path(kwargs): from dvc.dvcfile import DVC_FILE_SUFFIX, DVC_FILE @@ -72,7 +85,10 @@ def run(self, fname=None, no_exec=False, single_stage=False, **kwargs): if not is_valid_name(stage_name): raise InvalidStageName - stage = create_stage(stage_cls, repo=self, path=path, **kwargs) + params = parse_params(kwargs.pop("params", [])) + stage = create_stage( + stage_cls, repo=self, path=path, params=params, **kwargs + ) if stage is None: return None diff --git a/dvc/stage/loader.py b/dvc/stage/loader.py index 169d5d3039..40a2d2d89d 100644 --- a/dvc/stage/loader.py +++ b/dvc/stage/loader.py @@ -1,15 +1,15 @@ import logging import os -from collections import defaultdict from collections.abc import Mapping from copy import deepcopy from itertools import chain -from funcy import first +from funcy import lcat, project from dvc import dependency, output from ..dependency import ParamsDependency +from . import fill_stage_dependencies from .exceptions import StageNameUnspecified, StageNotFound logger = logging.getLogger(__name__) @@ -32,6 +32,7 @@ def __init__(self, dvcfile, stages_data, lockfile_data=None): @staticmethod def fill_from_lock(stage, lock_data): + """Fill values for params, checksums for outs and deps from lock.""" from .params import StageParams items = chain( @@ -46,8 +47,8 @@ def fill_from_lock(stage, lock_data): for key, item in items: if isinstance(item, ParamsDependency): # load the params with values inside lock dynamically - params = lock_data.get("params", {}).get(item.def_path, {}) - item._dyn_load(params) + lock_params = lock_data.get(stage.PARAM_PARAMS, {}) + item.fill_values(lock_params.get(item.def_path, {})) continue item.checksum = ( @@ -56,104 +57,6 @@ def fill_from_lock(stage, lock_data): .get(item.checksum_type) ) - @classmethod - def _load_params(cls, stage, pipeline_params): - """ - File in pipeline file is expected to be in following format: - ``` - params: - - lr - - train.epochs - - params2.yaml: # notice the filename - - process.threshold - - process.bow - ``` - - and, in lockfile, we keep it as following format: - ``` - params: - params.yaml: - lr: 0.0041 - train.epochs: 100 - params2.yaml: - process.threshold: 0.98 - process.bow: - - 15000 - - 123 - ``` - In the list of `params` inside pipeline file, if any of the item is - dict-like, the key will be treated as separate params file and it's - values to be part of that params file, else, the item is considered - as part of the `params.yaml` which is a default file. - - (From example above: `lr` is considered to be part of `params.yaml` - whereas `process.bow` to be part of `params2.yaml`.) - - We only load the keys here, lockfile bears the values which are used - to compare between the actual params from the file in the workspace. - """ - res = defaultdict(list) - for key in pipeline_params: - if isinstance(key, str): - path = DEFAULT_PARAMS_FILE - res[path].append(key) - elif isinstance(key, dict): - path = first(key) - res[path].extend(key[path]) - - stage.deps.extend( - dependency.loadd_from( - stage, - [ - {"path": key, "params": params} - for key, params in res.items() - ], - ) - ) - - @classmethod - def _load_outs(cls, stage, data, typ=None): - from dvc.output.base import BaseOutput - - d = [] - for key in data: - if isinstance(key, str): - entry = {BaseOutput.PARAM_PATH: key} - if typ: - entry[typ] = True - d.append(entry) - continue - - assert isinstance(key, dict) - assert len(key) == 1 - - path = first(key) - extra = key[path] - - if not typ: - d.append({BaseOutput.PARAM_PATH: path, **extra}) - continue - - entry = {BaseOutput.PARAM_PATH: path} - - persist = extra.pop(BaseOutput.PARAM_PERSIST, False) - if persist: - entry[BaseOutput.PARAM_PERSIST] = persist - - cache = extra.pop(BaseOutput.PARAM_CACHE, True) - if not cache: - entry[BaseOutput.PARAM_CACHE] = cache - - entry[typ] = extra or True - - d.append(entry) - - stage.outs.extend(output.loadd_from(stage, d)) - - @classmethod - def _load_deps(cls, stage, data): - stage.deps.extend(dependency.loads_from(stage, data)) - @classmethod def load_stage(cls, dvcfile, name, stage_data, lock_data): from . import PipelineStage, Stage, loads_from @@ -163,13 +66,18 @@ def load_stage(cls, dvcfile, name, stage_data, lock_data): ) stage = loads_from(PipelineStage, dvcfile.repo, path, wdir, stage_data) stage.name = name - stage.deps, stage.outs = [], [] - cls._load_outs(stage, stage_data.get("outs", [])) - cls._load_outs(stage, stage_data.get("metrics", []), "metric") - cls._load_outs(stage, stage_data.get("plots", []), "plot") - cls._load_deps(stage, stage_data.get("deps", [])) - cls._load_params(stage, stage_data.get("params", [])) + deps = project(stage_data, [stage.PARAM_DEPS, stage.PARAM_PARAMS]) + fill_stage_dependencies(stage, **deps) + + outs = project( + stage_data, + [stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS], + ) + stage.outs = lcat( + output.load_from_pipeline(stage, data, typ=key) + for key, data in outs.items() + ) if lock_data: stage.cmd_changed = lock_data.get( diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index 8728463411..b4edf53d6e 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -15,19 +15,40 @@ def test_loads_params(dvc): stage = Stage(dvc) - deps = loads_params(stage, ["foo", "bar,baz", "a_file:qux"]) - assert len(deps) == 2 + deps = loads_params( + stage, + [ + "foo", + "bar", + {"a_file": ["baz", "bat"]}, + {"b_file": ["cat"]}, + {}, + {"a_file": ["foobar"]}, + ], + ) + assert len(deps) == 3 assert isinstance(deps[0], ParamsDependency) assert deps[0].def_path == ParamsDependency.DEFAULT_PARAMS_FILE - assert deps[0].params == ["foo", "bar", "baz"] + assert deps[0].params == ["foo", "bar"] assert deps[0].info == {} assert isinstance(deps[1], ParamsDependency) assert deps[1].def_path == "a_file" - assert deps[1].params == ["qux"] + assert deps[1].params == ["baz", "bat", "foobar"] assert deps[1].info == {} + assert isinstance(deps[2], ParamsDependency) + assert deps[2].def_path == "b_file" + assert deps[2].params == ["cat"] + assert deps[2].info == {} + + +@pytest.mark.parametrize("params", [[3], [{"b_file": "cat"}]]) +def test_params_error(dvc, params): + with pytest.raises(ValueError): + loads_params(Stage(dvc), params) + def test_loadd_from(dvc): stage = Stage(dvc) diff --git a/tests/unit/output/test_load.py b/tests/unit/output/test_load.py new file mode 100644 index 0000000000..68171a09a9 --- /dev/null +++ b/tests/unit/output/test_load.py @@ -0,0 +1,114 @@ +import pytest + +from dvc import output +from dvc.output import LocalOutput, S3Output +from dvc.stage import Stage + + +@pytest.mark.parametrize( + "out_type,type_test_func", + [ + ("outs", lambda o: not (o.metric or o.plot)), + ("metrics", lambda o: o.metric and not o.plot), + ("plots", lambda o: o.plot and not o.metric), + ], + ids=("outs", "metrics", "plots"), +) +def test_load_from_pipeline(dvc, out_type, type_test_func): + outs = output.load_from_pipeline( + Stage(dvc), + [ + "file1", + "file2", + {"file3": {"cache": True}}, + {}, + {"file4": {"cache": False}}, + {"file5": {"persist": False}}, + {"file6": {"persist": True, "cache": False}}, + ], + out_type, + ) + cached_outs = {"file1", "file2", "file3", "file5"} + persisted_outs = {"file6"} + assert len(outs) == 6 + + for i, out in enumerate(outs, start=1): + assert isinstance(out, LocalOutput) + assert out.def_path == f"file{i}" + assert out.use_cache == (out.def_path in cached_outs) + assert out.persist == (out.def_path in persisted_outs) + assert out.info == {} + assert type_test_func(out) + + +def test_load_from_pipeline_accumulates_flag(dvc): + outs = output.load_from_pipeline( + Stage(dvc), + [ + "file1", + {"file2": {"cache": False}}, + {"file1": {"persist": False}}, + {"file2": {"persist": True}}, + ], + "outs", + ) + for out in outs: + assert isinstance(out, LocalOutput) + assert not out.plot and not out.metric + assert out.info == {} + + assert outs[0].use_cache and not outs[0].persist + assert not outs[1].use_cache and outs[1].persist + + +def test_load_remote_files_from_pipeline(dvc): + stage = Stage(dvc) + (out,) = output.load_from_pipeline( + stage, [{"s3://dvc-test/file.txt": {"cache": False}}], typ="metrics" + ) + assert isinstance(out, S3Output) + assert not out.plot and out.metric + assert not out.persist + assert out.info == {} + + +@pytest.mark.parametrize("typ", [None, "", "illegal"]) +def test_load_from_pipeline_error_on_typ(dvc, typ): + with pytest.raises(ValueError): + output.load_from_pipeline(Stage(dvc), ["file1"], typ) + + +@pytest.mark.parametrize("key", [3, "list".split()]) +def test_load_from_pipeline_illegal_type(dvc, key): + stage = Stage(dvc) + with pytest.raises(ValueError): + output.load_from_pipeline(stage, [key], "outs") + with pytest.raises(ValueError): + output.load_from_pipeline(stage, [{"key": key}], "outs") + + +def test_plots_load_from_pipeline(dvc): + outs = output.load_from_pipeline( + Stage(dvc), + [ + "file1", + { + "file2": { + "persist": True, + "cache": False, + "x": 3, + "random": "val", + } + }, + ], + "plots", + ) + assert isinstance(outs[0], LocalOutput) + assert outs[0].use_cache + assert outs[0].plot is True and not outs[0].metric + assert not outs[0].persist + + assert isinstance(outs[1], LocalOutput) + assert not outs[1].use_cache + assert outs[1].plot == {"x": 3} and not outs[1].metric + assert outs[1].persist diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py new file mode 100644 index 0000000000..425389f7c9 --- /dev/null +++ b/tests/unit/test_run.py @@ -0,0 +1,36 @@ +import pytest + +from dvc.repo.run import is_valid_name, parse_params + + +def test_parse_params(): + assert parse_params( + [ + "param1", + "file1:param1,param2", + "file2:param2", + "file1:param2,param3,", + "param1,param2", + "param3,", + "file3:", + ] + ) == [ + "param1", + {"file1": ["param1", "param2"]}, + {"file2": ["param2"]}, + {"file1": ["param2", "param3"]}, + "param1", + "param2", + "param3", + {"file3": []}, + ] + + +@pytest.mark.parametrize("name", ["copy_name", "copy-name", "copyName", "12"]) +def test_valid_stage_names(name): + assert is_valid_name(name) + + +@pytest.mark.parametrize("name", ["copy$name", "copy-name?", "copy-name@v1"]) +def test_invalid_stage_names(name): + assert not is_valid_name(name)