From d4e780699eba5c44587d62858a00d469e22981fa Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Wed, 3 Jun 2020 15:45:48 +0545 Subject: [PATCH 01/10] refactor multistage load for params and outputs --- dvc/dependency/__init__.py | 41 ++++++------ dvc/dependency/param.py | 2 +- dvc/output/__init__.py | 33 +++++++++- dvc/repo/run.py | 17 ++++- dvc/stage/loader.py | 124 +++++-------------------------------- 5 files changed, 84 insertions(+), 133 deletions(-) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 1fa2073e24..2cf449d036 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -1,6 +1,8 @@ from collections import defaultdict from urllib.parse import urlparse +from funcy import first + import dvc.output as output from dvc.dependency.azure import AzureDependency from dvc.dependency.gs import GSDependency @@ -87,26 +89,21 @@ def loads_from(stage, s_list, erepo=None): return ret -def _parse_params(path_params): - path, _, params_str = path_params.rpartition(":") - params = params_str.split(",") - return path, params - - def loads_params(stage, s_list): - # Creates an object for each unique file that is referenced in the list - params_by_path = defaultdict(list) - for s in s_list: - path, params = _parse_params(s) - params_by_path[path].extend(params) - - d_list = [] - for path, params in params_by_path.items(): - d_list.append( - { - BaseOutput.PARAM_PATH: path, - ParamsDependency.PARAM_PARAMS: params, - } - ) - - return loadd_from(stage, d_list) + d = defaultdict(list) + for key in s_list: + if isinstance(key, str): + path = ParamsDependency.DEFAULT_PARAMS_FILE + params = [key] + else: + assert isinstance(key, dict) + path = first(key) + if not path: + continue + params = key[path] + assert isinstance(params, list) + d[path].extend(params) + + return [ + ParamsDependency(stage, path, params) for path, params in d.items() + ] diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 9add46457d..78d0bf712f 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -40,7 +40,7 @@ def __init__(self, stage, path, params): info=info, ) - def _dyn_load(self, values=None): + def inject_values(self, values=None): """Load params values dynamically.""" if not values: return diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index 492b0e4862..f30640807b 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -1,5 +1,7 @@ +from collections import defaultdict from urllib.parse import urlparse +from funcy import first from voluptuous import And, Any, Coerce, Length, Lower, Required, SetTo from dvc.output.base import BaseOutput @@ -58,7 +60,9 @@ SCHEMA[BaseOutput.PARAM_PERSIST] = bool -def _get(stage, p, info, cache, metric, plot=False, persist=False): +def _get( + stage, p, info=None, cache=True, metric=False, plot=False, persist=False +): parsed = urlparse(p) if parsed.scheme == "remote": @@ -135,3 +139,30 @@ def loads_from( ) for s in s_list ] + + +def load_from_pipeline(stage, s_list, typ): + out_types = { + stage.PARAM_OUTS: {}, + stage.PARAM_METRICS: {"metric": True}, + stage.PARAM_PARAMS: {"param": True}, + } + extra = out_types[typ] + d = defaultdict(dict) + for key in s_list: + flags = {} + if isinstance(key, str): + path = key + else: + assert isinstance(key, dict) + path = first(key) + if not path: + continue + flags = key[path] + assert isinstance(flags, dict) + d[path].update(flags) + + return [ + _get(stage, path, info={}, **flags, **extra) + for path, flags in d.items() + ] diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 313d00feaa..635e96764c 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -20,6 +20,18 @@ def is_valid_name(name: str): return not INVALID_STAGENAME_CHARS & set(name) +def parse_params(path_params): + ret = [] + for path_param in path_params: + path, _, params_str = path_param.rpartition(":") + params = params_str.split(",") + if not path: + ret.extend(params) + else: + ret.append({path: params}) + return ret + + def _get_file_path(kwargs): from dvc.dvcfile import DVC_FILE_SUFFIX, DVC_FILE @@ -72,7 +84,10 @@ def run(self, fname=None, no_exec=False, single_stage=False, **kwargs): if not is_valid_name(stage_name): raise InvalidStageName - stage = create_stage(stage_cls, repo=self, path=path, **kwargs) + params = parse_params(kwargs.pop("params", [])) + stage = create_stage( + stage_cls, repo=self, path=path, params=params, **kwargs + ) if stage is None: return None diff --git a/dvc/stage/loader.py b/dvc/stage/loader.py index 169d5d3039..09aa8da911 100644 --- a/dvc/stage/loader.py +++ b/dvc/stage/loader.py @@ -1,15 +1,15 @@ import logging import os -from collections import defaultdict from collections.abc import Mapping from copy import deepcopy from itertools import chain -from funcy import first +from funcy import lcat, project from dvc import dependency, output from ..dependency import ParamsDependency +from . import fill_stage_dependencies from .exceptions import StageNameUnspecified, StageNotFound logger = logging.getLogger(__name__) @@ -32,6 +32,7 @@ def __init__(self, dvcfile, stages_data, lockfile_data=None): @staticmethod def fill_from_lock(stage, lock_data): + """Fill values for params, checksums for outs and deps from lock.""" from .params import StageParams items = chain( @@ -46,8 +47,8 @@ def fill_from_lock(stage, lock_data): for key, item in items: if isinstance(item, ParamsDependency): # load the params with values inside lock dynamically - params = lock_data.get("params", {}).get(item.def_path, {}) - item._dyn_load(params) + lock_params = lock_data.get(stage.PARAM_PARAMS, {}) + item.inject_values(lock_params.get(item.def_path, {})) continue item.checksum = ( @@ -56,104 +57,6 @@ def fill_from_lock(stage, lock_data): .get(item.checksum_type) ) - @classmethod - def _load_params(cls, stage, pipeline_params): - """ - File in pipeline file is expected to be in following format: - ``` - params: - - lr - - train.epochs - - params2.yaml: # notice the filename - - process.threshold - - process.bow - ``` - - and, in lockfile, we keep it as following format: - ``` - params: - params.yaml: - lr: 0.0041 - train.epochs: 100 - params2.yaml: - process.threshold: 0.98 - process.bow: - - 15000 - - 123 - ``` - In the list of `params` inside pipeline file, if any of the item is - dict-like, the key will be treated as separate params file and it's - values to be part of that params file, else, the item is considered - as part of the `params.yaml` which is a default file. - - (From example above: `lr` is considered to be part of `params.yaml` - whereas `process.bow` to be part of `params2.yaml`.) - - We only load the keys here, lockfile bears the values which are used - to compare between the actual params from the file in the workspace. - """ - res = defaultdict(list) - for key in pipeline_params: - if isinstance(key, str): - path = DEFAULT_PARAMS_FILE - res[path].append(key) - elif isinstance(key, dict): - path = first(key) - res[path].extend(key[path]) - - stage.deps.extend( - dependency.loadd_from( - stage, - [ - {"path": key, "params": params} - for key, params in res.items() - ], - ) - ) - - @classmethod - def _load_outs(cls, stage, data, typ=None): - from dvc.output.base import BaseOutput - - d = [] - for key in data: - if isinstance(key, str): - entry = {BaseOutput.PARAM_PATH: key} - if typ: - entry[typ] = True - d.append(entry) - continue - - assert isinstance(key, dict) - assert len(key) == 1 - - path = first(key) - extra = key[path] - - if not typ: - d.append({BaseOutput.PARAM_PATH: path, **extra}) - continue - - entry = {BaseOutput.PARAM_PATH: path} - - persist = extra.pop(BaseOutput.PARAM_PERSIST, False) - if persist: - entry[BaseOutput.PARAM_PERSIST] = persist - - cache = extra.pop(BaseOutput.PARAM_CACHE, True) - if not cache: - entry[BaseOutput.PARAM_CACHE] = cache - - entry[typ] = extra or True - - d.append(entry) - - stage.outs.extend(output.loadd_from(stage, d)) - - @classmethod - def _load_deps(cls, stage, data): - stage.deps.extend(dependency.loads_from(stage, data)) - @classmethod def load_stage(cls, dvcfile, name, stage_data, lock_data): from . import PipelineStage, Stage, loads_from @@ -163,13 +66,18 @@ def load_stage(cls, dvcfile, name, stage_data, lock_data): ) stage = loads_from(PipelineStage, dvcfile.repo, path, wdir, stage_data) stage.name = name - stage.deps, stage.outs = [], [] - cls._load_outs(stage, stage_data.get("outs", [])) - cls._load_outs(stage, stage_data.get("metrics", []), "metric") - cls._load_outs(stage, stage_data.get("plots", []), "plot") - cls._load_deps(stage, stage_data.get("deps", [])) - cls._load_params(stage, stage_data.get("params", [])) + deps = project(stage_data, [stage.PARAM_DEPS, stage.PARAM_PARAMS]) + fill_stage_dependencies(stage, **deps) + + outs = project( + stage_data, + [stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS], + ) + stage.outs = lcat( + output.load_from_pipeline(stage, data, typ=key) + for key, data in outs.items() + ) if lock_data: stage.cmd_changed = lock_data.get( From f9a6687a3dc90c1d03d17185bfe6975d9cdd8c27 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Wed, 3 Jun 2020 17:03:42 +0545 Subject: [PATCH 02/10] tests: load params --- dvc/dependency/__init__.py | 10 ++++++-- dvc/repo/run.py | 5 ++-- tests/unit/dependency/test_params.py | 29 ++++++++++++++++++---- tests/unit/test_run.py | 36 ++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 tests/unit/test_run.py diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 2cf449d036..ab94fcb556 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -96,12 +96,18 @@ def loads_params(stage, s_list): path = ParamsDependency.DEFAULT_PARAMS_FILE params = [key] else: - assert isinstance(key, dict) + if not isinstance(key, dict): + msg = "Only list of str/dict is supported. Got: " + msg += f"'{type(key).__name__}'." + raise ValueError(msg) path = first(key) if not path: continue params = key[path] - assert isinstance(params, list) + if not isinstance(params, list): + msg = "Expected list of params for custom params file " + msg += f"'{path}', got '{type(params).__name__}'." + raise ValueError(msg) d[path].extend(params) return [ diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 635e96764c..36c52726ca 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -1,6 +1,6 @@ import os -from funcy import concat, first +from funcy import concat, first, lfilter from dvc.exceptions import InvalidArgumentError from dvc.stage.exceptions import ( @@ -24,7 +24,8 @@ def parse_params(path_params): ret = [] for path_param in path_params: path, _, params_str = path_param.rpartition(":") - params = params_str.split(",") + # remove empty strings from params, on condition such as `-p "file1:"` + params = lfilter(bool, params_str.split(",")) if not path: ret.extend(params) else: diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index 8728463411..b4edf53d6e 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -15,19 +15,40 @@ def test_loads_params(dvc): stage = Stage(dvc) - deps = loads_params(stage, ["foo", "bar,baz", "a_file:qux"]) - assert len(deps) == 2 + deps = loads_params( + stage, + [ + "foo", + "bar", + {"a_file": ["baz", "bat"]}, + {"b_file": ["cat"]}, + {}, + {"a_file": ["foobar"]}, + ], + ) + assert len(deps) == 3 assert isinstance(deps[0], ParamsDependency) assert deps[0].def_path == ParamsDependency.DEFAULT_PARAMS_FILE - assert deps[0].params == ["foo", "bar", "baz"] + assert deps[0].params == ["foo", "bar"] assert deps[0].info == {} assert isinstance(deps[1], ParamsDependency) assert deps[1].def_path == "a_file" - assert deps[1].params == ["qux"] + assert deps[1].params == ["baz", "bat", "foobar"] assert deps[1].info == {} + assert isinstance(deps[2], ParamsDependency) + assert deps[2].def_path == "b_file" + assert deps[2].params == ["cat"] + assert deps[2].info == {} + + +@pytest.mark.parametrize("params", [[3], [{"b_file": "cat"}]]) +def test_params_error(dvc, params): + with pytest.raises(ValueError): + loads_params(Stage(dvc), params) + def test_loadd_from(dvc): stage = Stage(dvc) diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py new file mode 100644 index 0000000000..ae910546d1 --- /dev/null +++ b/tests/unit/test_run.py @@ -0,0 +1,36 @@ +import pytest + +from dvc.repo.run import is_valid_name, parse_params + + +def test_parse_params(): + assert parse_params( + [ + "param1", + "file1:param1,param2", + "file2:param2", + "file1:param2,param3,", + "param1,param2", + "param3,", + "file3:", + ] + ) == [ + "param1", + {"file1": ["param1", "param2"]}, + {"file2": ["param2"]}, + {"file1": ["param2", "param3"]}, + "param1", + "param2", + "param3", + {"file3": []}, + ] + + +@pytest.mark.parametrize("name", ["copy_name", "copy-name", "copy-name", "12"]) +def test_valid_stage_names(name): + assert is_valid_name(name) + + +@pytest.mark.parametrize("name", ["copy$name", "copy-name?", "copy-name@v1"]) +def test_invalid_stage_names(name): + assert not is_valid_name(name) From 6a0028e96dde7bf9fb7e7b1df2642c342c3a95af Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Wed, 3 Jun 2020 21:38:42 +0545 Subject: [PATCH 03/10] tests: output loading from pipeline file --- dvc/output/__init__.py | 54 ++++++++++++---- tests/unit/output/test_load.py | 115 +++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+), 14 deletions(-) create mode 100644 tests/unit/output/test_load.py diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index f30640807b..b578e45610 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -1,7 +1,7 @@ from collections import defaultdict from urllib.parse import urlparse -from funcy import first +from funcy import collecting, first, project from voluptuous import And, Any, Coerce, Length, Lower, Required, SetTo from dvc.output.base import BaseOutput @@ -141,28 +141,54 @@ def loads_from( ] -def load_from_pipeline(stage, s_list, typ): - out_types = { - stage.PARAM_OUTS: {}, - stage.PARAM_METRICS: {"metric": True}, - stage.PARAM_PARAMS: {"param": True}, - } - extra = out_types[typ] +def _split_plot_data_and_flags(flags): + from dvc.schema import PLOT_PROPS + + plot_data = project(flags, PLOT_PROPS) + flags = project(flags, flags.keys() - PLOT_PROPS.keys()) + return plot_data if plot_data else True, flags + + +def _merge_data(s_list): d = defaultdict(dict) for key in s_list: flags = {} if isinstance(key, str): path = key else: - assert isinstance(key, dict) + if not isinstance(key, dict): + raise ValueError(f"'{type(key).__name__}' not supported.") path = first(key) if not path: continue flags = key[path] - assert isinstance(flags, dict) + if not isinstance(flags, dict): + raise ValueError( + f"Expected dict for '{path}', got: '{type(key).__name__}'" + ) d[path].update(flags) + return d - return [ - _get(stage, path, info={}, **flags, **extra) - for path, flags in d.items() - ] + +@collecting +def load_from_pipeline(stage, s_list, typ="outs"): + out_types = { + stage.PARAM_OUTS: {}, + stage.PARAM_METRICS: {BaseOutput.PARAM_METRIC: True}, + stage.PARAM_PLOTS: {BaseOutput.PARAM_PLOT: True}, + } + if typ not in out_types: + raise ValueError(f"'{typ}' key is not allowed for pipeline files.") + + d = _merge_data(s_list) + for path, flags in d.items(): + if typ == stage.PARAM_PLOTS: + plot_data, flags = _split_plot_data_and_flags(flags) + out_types[typ][BaseOutput.PARAM_PLOT] = plot_data + yield _get( + stage, + path, + info={}, + **project(flags, ["cache", "persist"]), + **out_types[typ], + ) diff --git a/tests/unit/output/test_load.py b/tests/unit/output/test_load.py new file mode 100644 index 0000000000..729c951a2e --- /dev/null +++ b/tests/unit/output/test_load.py @@ -0,0 +1,115 @@ +import pytest + +from dvc import output +from dvc.output import LocalOutput, S3Output +from dvc.stage import Stage + + +@pytest.mark.parametrize( + "out_type,type_test_func", + [ + ("outs", lambda o: not (o.metric or o.plot)), + ("metrics", lambda o: o.metric and not o.plot), + ("plots", lambda o: o.plot and not o.metric), + ], + ids=("outs", "metrics", "plots"), +) +def test_load_from_pipeline(dvc, out_type, type_test_func): + outs = output.load_from_pipeline( + Stage(dvc), + [ + "file1", + "file2", + {"file3": {"cache": True}}, + {}, + {"file4": {"cache": False}}, + {"file5": {"persist": False}}, + {"file6": {"persist": True, "cache": False}}, + ], + out_type, + ) + cached_outs = {"file1", "file2", "file3", "file5"} + persisted_outs = {"file6"} + assert len(outs) == 6 + + for i, out in enumerate(outs, start=1): + assert isinstance(out, LocalOutput) + assert out.def_path == f"file{i}" + assert out.use_cache == (out.def_path in cached_outs) + assert out.persist == (out.def_path in persisted_outs) + assert out.info == {} + assert type_test_func(out) + + +def test_load_from_pipeline_accumulates_flag(dvc): + outs = output.load_from_pipeline( + Stage(dvc), + [ + "file1", + {"file2": {"cache": False}}, + {"file1": {"persist": False}}, + {"file2": {"persist": False}}, + ], + "outs", + ) + for out in outs: + assert isinstance(out, LocalOutput) + assert not out.plot and not out.metric + assert not out.persist + assert out.info == {} + + assert outs[0].use_cache + assert not outs[1].use_cache + + +def test_load_remote_files_from_pipeline(dvc): + stage = Stage(dvc) + (out,) = output.load_from_pipeline( + stage, [{"s3://dvc-test/file.txt": {"cache": False}}], typ="metrics" + ) + assert isinstance(out, S3Output) + assert not out.plot and out.metric + assert not out.persist + assert out.info == {} + + +@pytest.mark.parametrize("typ", [None, "", "illegal"]) +def test_load_from_pipeline_error_on_typ(dvc, typ): + with pytest.raises(ValueError): + output.load_from_pipeline(Stage(dvc), typ, None) + + +@pytest.mark.parametrize("key", [3, "list".split()]) +def test_load_from_pipeline_illegal_type(dvc, key): + stage = Stage(dvc) + with pytest.raises(ValueError): + output.load_from_pipeline(stage, [key], "outs") + with pytest.raises(ValueError): + output.load_from_pipeline(stage, [{"key": key}], "outs") + + +def test_params_load_from_pipeline(dvc): + outs = output.load_from_pipeline( + Stage(dvc), + [ + "file1", + { + "file2": { + "persist": True, + "cache": False, + "x": 3, + "random": "val", + } + }, + ], + "plots", + ) + assert isinstance(outs[0], LocalOutput) + assert outs[0].use_cache + assert outs[0].plot is True and not outs[0].metric + assert not outs[0].persist + + assert isinstance(outs[0], LocalOutput) + assert not outs[1].use_cache + assert outs[1].plot == {"x": 3} and not outs[0].metric + assert outs[1].persist From 45cb2610ca14dab7c1b59607a90f7e26de97a3f2 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Wed, 3 Jun 2020 21:58:56 +0545 Subject: [PATCH 04/10] fix test --- tests/unit/output/test_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/output/test_load.py b/tests/unit/output/test_load.py index 729c951a2e..4869eb3fa2 100644 --- a/tests/unit/output/test_load.py +++ b/tests/unit/output/test_load.py @@ -76,7 +76,7 @@ def test_load_remote_files_from_pipeline(dvc): @pytest.mark.parametrize("typ", [None, "", "illegal"]) def test_load_from_pipeline_error_on_typ(dvc, typ): with pytest.raises(ValueError): - output.load_from_pipeline(Stage(dvc), typ, None) + output.load_from_pipeline(Stage(dvc), ["file1"], typ) @pytest.mark.parametrize("key", [3, "list".split()]) From 40ae3c7c825b8cb0aa6831e165ef9d55decca36e Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Wed, 3 Jun 2020 22:08:34 +0545 Subject: [PATCH 05/10] fix typo in name --- tests/unit/output/test_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/output/test_load.py b/tests/unit/output/test_load.py index 4869eb3fa2..cb71666337 100644 --- a/tests/unit/output/test_load.py +++ b/tests/unit/output/test_load.py @@ -88,7 +88,7 @@ def test_load_from_pipeline_illegal_type(dvc, key): output.load_from_pipeline(stage, [{"key": key}], "outs") -def test_params_load_from_pipeline(dvc): +def test_plots_load_from_pipeline(dvc): outs = output.load_from_pipeline( Stage(dvc), [ From 44cb8378543afe1bc03f32e2e8b8675caa318d8d Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Wed, 3 Jun 2020 22:15:17 +0545 Subject: [PATCH 06/10] split params load --- dvc/dependency/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index ab94fcb556..0856896423 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -89,7 +89,7 @@ def loads_from(stage, s_list, erepo=None): return ret -def loads_params(stage, s_list): +def _merge_params(s_list): d = defaultdict(list) for key in s_list: if isinstance(key, str): @@ -109,7 +109,11 @@ def loads_params(stage, s_list): msg += f"'{path}', got '{type(params).__name__}'." raise ValueError(msg) d[path].extend(params) + return d + +def loads_params(stage, s_list): + d = _merge_params(s_list) return [ ParamsDependency(stage, path, params) for path, params in d.items() ] From 3a0a9b84cc5e92e1f16005003a77596d7a60616c Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 4 Jun 2020 09:28:42 +0545 Subject: [PATCH 07/10] rename params func s/inject_values/fill_values --- dvc/dependency/param.py | 2 +- dvc/stage/loader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 78d0bf712f..4df6087f52 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -40,7 +40,7 @@ def __init__(self, stage, path, params): info=info, ) - def inject_values(self, values=None): + def fill_values(self, values=None): """Load params values dynamically.""" if not values: return diff --git a/dvc/stage/loader.py b/dvc/stage/loader.py index 09aa8da911..40a2d2d89d 100644 --- a/dvc/stage/loader.py +++ b/dvc/stage/loader.py @@ -48,7 +48,7 @@ def fill_from_lock(stage, lock_data): if isinstance(item, ParamsDependency): # load the params with values inside lock dynamically lock_params = lock_data.get(stage.PARAM_PARAMS, {}) - item.inject_values(lock_params.get(item.def_path, {})) + item.fill_values(lock_params.get(item.def_path, {})) continue item.checksum = ( From 54a450e6f0a8529810068095af25157320f0a80d Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 4 Jun 2020 09:32:41 +0545 Subject: [PATCH 08/10] fix tests --- tests/unit/output/test_load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/output/test_load.py b/tests/unit/output/test_load.py index cb71666337..50b0a82d89 100644 --- a/tests/unit/output/test_load.py +++ b/tests/unit/output/test_load.py @@ -109,7 +109,7 @@ def test_plots_load_from_pipeline(dvc): assert outs[0].plot is True and not outs[0].metric assert not outs[0].persist - assert isinstance(outs[0], LocalOutput) + assert isinstance(outs[1], LocalOutput) assert not outs[1].use_cache - assert outs[1].plot == {"x": 3} and not outs[0].metric + assert outs[1].plot == {"x": 3} and not outs[1].metric assert outs[1].persist From 339d833432cd937fb5f97f277d30dc7a39d8dc46 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 4 Jun 2020 22:32:45 +0545 Subject: [PATCH 09/10] simplify loads_params and output.load_from_pipeline --- dvc/dependency/__init__.py | 26 +++++++--------- dvc/output/__init__.py | 63 ++++++++++++++++---------------------- 2 files changed, 37 insertions(+), 52 deletions(-) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 0856896423..10a2f787a6 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -1,8 +1,6 @@ from collections import defaultdict from urllib.parse import urlparse -from funcy import first - import dvc.output as output from dvc.dependency.azure import AzureDependency from dvc.dependency.gs import GSDependency @@ -91,24 +89,22 @@ def loads_from(stage, s_list, erepo=None): def _merge_params(s_list): d = defaultdict(list) + default_file = ParamsDependency.DEFAULT_PARAMS_FILE for key in s_list: if isinstance(key, str): - path = ParamsDependency.DEFAULT_PARAMS_FILE - params = [key] - else: - if not isinstance(key, dict): - msg = "Only list of str/dict is supported. Got: " - msg += f"'{type(key).__name__}'." - raise ValueError(msg) - path = first(key) - if not path: - continue - params = key[path] + d[default_file].append(key) + continue + if not isinstance(key, dict): + msg = "Only list of str/dict is supported. Got: " + msg += f"'{type(key).__name__}'." + raise ValueError(msg) + + for k, params in key.items(): if not isinstance(params, list): msg = "Expected list of params for custom params file " - msg += f"'{path}', got '{type(params).__name__}'." + msg += f"'{k}', got '{type(params).__name__}'." raise ValueError(msg) - d[path].extend(params) + d[k].extend(params) return d diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index b578e45610..74a550a8f5 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -1,7 +1,7 @@ from collections import defaultdict from urllib.parse import urlparse -from funcy import collecting, first, project +from funcy import collecting, project from voluptuous import And, Any, Coerce, Length, Lower, Required, SetTo from dvc.output.base import BaseOutput @@ -141,54 +141,43 @@ def loads_from( ] -def _split_plot_data_and_flags(flags): - from dvc.schema import PLOT_PROPS - - plot_data = project(flags, PLOT_PROPS) - flags = project(flags, flags.keys() - PLOT_PROPS.keys()) - return plot_data if plot_data else True, flags +def _split_dict(d, keys): + return project(d, keys), project(d, d.keys() - keys) def _merge_data(s_list): d = defaultdict(dict) for key in s_list: - flags = {} if isinstance(key, str): - path = key - else: - if not isinstance(key, dict): - raise ValueError(f"'{type(key).__name__}' not supported.") - path = first(key) - if not path: - continue - flags = key[path] - if not isinstance(flags, dict): - raise ValueError( - f"Expected dict for '{path}', got: '{type(key).__name__}'" - ) - d[path].update(flags) + d[key].update({}) + continue + if not isinstance(key, dict): + raise ValueError(f"'{type(key).__name__}' not supported.") + + for k, flags in key.items(): + if not isinstance(flags, dict): + raise ValueError( + f"Expected dict for '{k}', got: '{type(flags).__name__}'" + ) + d[k].update(flags) return d @collecting def load_from_pipeline(stage, s_list, typ="outs"): - out_types = { - stage.PARAM_OUTS: {}, - stage.PARAM_METRICS: {BaseOutput.PARAM_METRIC: True}, - stage.PARAM_PLOTS: {BaseOutput.PARAM_PLOT: True}, - } - if typ not in out_types: + if typ not in (stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS): raise ValueError(f"'{typ}' key is not allowed for pipeline files.") + metric = typ == stage.PARAM_METRICS + plot = typ == stage.PARAM_PLOTS + d = _merge_data(s_list) + for path, flags in d.items(): - if typ == stage.PARAM_PLOTS: - plot_data, flags = _split_plot_data_and_flags(flags) - out_types[typ][BaseOutput.PARAM_PLOT] = plot_data - yield _get( - stage, - path, - info={}, - **project(flags, ["cache", "persist"]), - **out_types[typ], - ) + plt_d = {} + if plot: + from dvc.schema import PLOT_PROPS + + plt_d, flags = _split_dict(flags, keys=PLOT_PROPS.keys()) + extra = project(flags, ["cache", "persist"]) + yield _get(stage, path, {}, plot=plt_d or plot, metric=metric, **extra) From 0fdd944a868109ecdcc6d12f2587ab91d639b9ed Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Thu, 4 Jun 2020 22:40:56 +0545 Subject: [PATCH 10/10] address @pared's suggestions for tests --- tests/unit/output/test_load.py | 7 +++---- tests/unit/test_run.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/unit/output/test_load.py b/tests/unit/output/test_load.py index 50b0a82d89..68171a09a9 100644 --- a/tests/unit/output/test_load.py +++ b/tests/unit/output/test_load.py @@ -48,18 +48,17 @@ def test_load_from_pipeline_accumulates_flag(dvc): "file1", {"file2": {"cache": False}}, {"file1": {"persist": False}}, - {"file2": {"persist": False}}, + {"file2": {"persist": True}}, ], "outs", ) for out in outs: assert isinstance(out, LocalOutput) assert not out.plot and not out.metric - assert not out.persist assert out.info == {} - assert outs[0].use_cache - assert not outs[1].use_cache + assert outs[0].use_cache and not outs[0].persist + assert not outs[1].use_cache and outs[1].persist def test_load_remote_files_from_pipeline(dvc): diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index ae910546d1..425389f7c9 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -26,7 +26,7 @@ def test_parse_params(): ] -@pytest.mark.parametrize("name", ["copy_name", "copy-name", "copy-name", "12"]) +@pytest.mark.parametrize("name", ["copy_name", "copy-name", "copyName", "12"]) def test_valid_stage_names(name): assert is_valid_name(name)