diff --git a/dvc/command/run.py b/dvc/command/run.py index aa4f7ded5a..fdac98bcaa 100644 --- a/dvc/command/run.py +++ b/dvc/command/run.py @@ -22,13 +22,14 @@ def run(self): self.args.metrics_no_cache, self.args.outs_persist, self.args.outs_persist_no_cache, + self.args.params, self.args.command, ] ): # pragma: no cover logger.error( "too few arguments. Specify at least one: `-d`, `-o`, `-O`, " - "`-m`, `-M`, `--outs-persist`, `--outs-persist-no-cache`, " - "`command`." + "`-m`, `-M`, `-p`, `--outs-persist`, " + "`--outs-persist-no-cache`, `command`." ) return 1 @@ -40,6 +41,7 @@ def run(self): metrics=self.args.metrics, metrics_no_cache=self.args.metrics_no_cache, deps=self.args.deps, + params=self.args.params, fname=self.args.file, cwd=self.args.cwd, wdir=self.args.wdir, @@ -111,6 +113,13 @@ def add_parser(subparsers, parent_parser): help="Declare output file or directory " "(do not put into DVC cache).", ) + run_parser.add_argument( + "-p", + "--params", + action="append", + default=[], + help="Declare parameter to use as additional dependency.", + ) run_parser.add_argument( "-m", "--metrics", diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index a968f01443..f6346cdeec 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -1,4 +1,5 @@ from urllib.parse import urlparse +from collections import defaultdict import dvc.output as output from dvc.dependency.gs import DependencyGS @@ -8,6 +9,7 @@ from dvc.dependency.local import DependencyLOCAL from dvc.dependency.s3 import DependencyS3 from dvc.dependency.ssh import DependencySSH +from dvc.dependency.param import DependencyPARAMS from dvc.output.base import OutputBase from dvc.remote import Remote from dvc.scheme import Schemes @@ -42,13 +44,13 @@ SCHEMA = output.SCHEMA.copy() del SCHEMA[OutputBase.PARAM_CACHE] del SCHEMA[OutputBase.PARAM_METRIC] -SCHEMA[DependencyREPO.PARAM_REPO] = DependencyREPO.REPO_SCHEMA +SCHEMA.update(DependencyREPO.REPO_SCHEMA) +SCHEMA.update(DependencyPARAMS.PARAM_SCHEMA) def _get(stage, p, info): - parsed = urlparse(p) - - if parsed.scheme == "remote": + parsed = urlparse(p) if p else None + if parsed and parsed.scheme == "remote": remote = Remote(stage.repo, name=parsed.netloc) return DEP_MAP[remote.scheme](stage, p, info, remote=remote) @@ -56,6 +58,10 @@ def _get(stage, p, info): repo = info.pop(DependencyREPO.PARAM_REPO) return DependencyREPO(repo, stage, p, info) + if info and info.get(DependencyPARAMS.PARAM_PARAMS): + params = info.pop(DependencyPARAMS.PARAM_PARAMS) + return DependencyPARAMS(stage, p, params) + for d in DEPS: if d.supported(p): return d(stage, p, info) @@ -65,7 +71,7 @@ def _get(stage, p, info): def loadd_from(stage, d_list): ret = [] for d in d_list: - p = d.pop(OutputBase.PARAM_PATH) + p = d.pop(OutputBase.PARAM_PATH, None) ret.append(_get(stage, p, d)) return ret @@ -76,3 +82,28 @@ def loads_from(stage, s_list, erepo=None): info = {DependencyREPO.PARAM_REPO: erepo} if erepo else {} ret.append(_get(stage, s, info)) return ret + + +def _parse_params(path_params): + path, _, params_str = path_params.rpartition(":") + params = params_str.split(",") + return path, params + + +def loads_params(stage, s_list): + # Creates an object for each unique file that is referenced in the list + params_by_path = defaultdict(list) + for s in s_list: + path, params = _parse_params(s) + params_by_path[path].extend(params) + + d_list = [] + for path, params in params_by_path.items(): + d_list.append( + { + OutputBase.PARAM_PATH: path, + DependencyPARAMS.PARAM_PARAMS: params, + } + ) + + return loadd_from(stage, d_list) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py new file mode 100644 index 0000000000..c5cd4a2b34 --- /dev/null +++ b/dvc/dependency/param.py @@ -0,0 +1,105 @@ +import os +import yaml +from collections import defaultdict + +from voluptuous import Any +from funcy import select_keys +from flatten_json import flatten + +from dvc.compat import fspath_py35 +from dvc.dependency.local import DependencyLOCAL +from dvc.exceptions import DvcException + + +class MissingParamsError(DvcException): + pass + + +class BadParamFileError(DvcException): + pass + + +class DependencyPARAMS(DependencyLOCAL): + PARAM_PARAMS = "params" + PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)} + DEFAULT_PARAMS_FILE = "params.yaml" + + def __init__(self, stage, path, params): + info = {} + self.params = [] + if params: + if isinstance(params, list): + self.params = params + else: + assert isinstance(params, dict) + self.params = list(params.keys()) + info = params + + super().__init__( + stage, + path + or os.path.join(stage.repo.root_dir, self.DEFAULT_PARAMS_FILE), + info=info, + ) + + def save(self): + super().save() + self.info = self.save_info() + + def status(self): + status = super().status() + + if status[str(self)] == "deleted": + return status + + status = defaultdict(dict) + info = self._get_info() + for param in self.params: + if param not in info.keys(): + st = "deleted" + elif param not in self.info: + st = "new" + elif info[param] != self.info[param]: + st = "modified" + else: + assert info[param] == self.info[param] + continue + + status[str(self)][param] = st + + return status + + def dumpd(self): + return { + self.PARAM_PATH: self.def_path, + self.PARAM_PARAMS: self.info or self.params, + } + + def _get_info(self): + if not self.exists: + return {} + + with open(fspath_py35(self.path_info), "r") as fobj: + try: + config = yaml.safe_load(fobj) + except yaml.YAMLError as exc: + raise BadParamFileError( + "Unable to read parameters from '{}'".format(self) + ) from exc + + config = flatten(config, ".") + + return select_keys(lambda key: key in self.params, config) + + def save_info(self): + info = self._get_info() + + missing_params = set(self.params) - set(info.keys()) + if missing_params: + raise MissingParamsError( + "Parameters '{}' are missing from '{}'.".format( + ", ".join(missing_params), self, + ) + ) + + return info diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 53edac176b..56292c4272 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -14,9 +14,11 @@ class DependencyREPO(DependencyLOCAL): PARAM_REV_LOCK = "rev_lock" REPO_SCHEMA = { - Required(PARAM_URL): str, - PARAM_REV: str, - PARAM_REV_LOCK: str, + PARAM_REPO: { + Required(PARAM_URL): str, + PARAM_REV: str, + PARAM_REV_LOCK: str, + } } def __init__(self, def_repo, stage, *args, **kwargs): diff --git a/dvc/output/base.py b/dvc/output/base.py index d07010ad6a..f6f7856de5 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -159,6 +159,9 @@ def is_dir_checksum(self): def exists(self): return self.remote.exists(self.path_info) + def save_info(self): + return self.remote.save_info(self.path_info) + def changed_checksum(self): return self.checksum != self.remote.get_checksum(self.path_info) @@ -215,7 +218,7 @@ def save(self): self.repo.scm.ignore(self.fspath) if not self.use_cache: - self.info = self.remote.save_info(self.path_info) + self.info = self.save_info() if self.metric: self.verify_metric() if not self.IS_DEPENDENCY: @@ -234,7 +237,7 @@ def save(self): ) return - self.info = self.remote.save_info(self.path_info) + self.info = self.save_info() def commit(self): if self.use_cache: diff --git a/dvc/stage.py b/dvc/stage.py index a3e015a6be..9d8ac9225b 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -533,9 +533,11 @@ def create(repo, accompany_outs=False, **kwargs): ) Stage._fill_stage_outputs(stage, **kwargs) - stage.deps = dependency.loads_from( + deps = dependency.loads_from( stage, kwargs.get("deps", []), erepo=kwargs.get("erepo", None) ) + params = dependency.loads_params(stage, kwargs.get("params", [])) + stage.deps = deps + params stage._check_circular_dependency() stage._check_duplicated_arguments() diff --git a/tests/unit/command/test_run.py b/tests/unit/command/test_run.py index f11b8e7a18..2b6dcfe659 100644 --- a/tests/unit/command/test_run.py +++ b/tests/unit/command/test_run.py @@ -33,6 +33,10 @@ def test_run(mocker, dvc): "--outs-persist-no-cache", "outs-persist-no-cache", "--always-changed", + "--params", + "file:param1,param2", + "--params", + "param3", "command", ] ) @@ -51,6 +55,7 @@ def test_run(mocker, dvc): metrics_no_cache=["metrics-no-cache"], outs_persist=["outs-persist"], outs_persist_no_cache=["outs-persist-no-cache"], + params=["file:param1,param2", "param3"], fname="file", cwd="cwd", wdir="wdir", @@ -77,6 +82,7 @@ def test_run_args_from_cli(mocker, dvc): metrics_no_cache=[], outs_persist=[], outs_persist_no_cache=[], + params=[], fname=None, cwd=None, wdir=None, @@ -103,6 +109,7 @@ def test_run_args_with_spaces(mocker, dvc): metrics_no_cache=[], outs_persist=[], outs_persist_no_cache=[], + params=[], fname=None, cwd=None, wdir=None, diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py new file mode 100644 index 0000000000..72d82d099a --- /dev/null +++ b/tests/unit/dependency/test_params.py @@ -0,0 +1,81 @@ +import yaml + +import pytest + +from dvc.dependency import DependencyPARAMS, loads_params, loadd_from +from dvc.dependency.param import BadParamFileError, MissingParamsError +from dvc.stage import Stage + + +PARAMS = { + "foo": 1, + "bar": 53.135, + "baz": "str", + "qux": None, +} + + +def test_loads_params(dvc): + stage = Stage(dvc) + deps = loads_params(stage, ["foo", "bar,baz", "a_file:qux"]) + assert len(deps) == 2 + + assert isinstance(deps[0], DependencyPARAMS) + assert deps[0].def_path == DependencyPARAMS.DEFAULT_PARAMS_FILE + assert deps[0].params == ["foo", "bar", "baz"] + assert deps[0].info == {} + + assert isinstance(deps[1], DependencyPARAMS) + assert deps[1].def_path == "a_file" + assert deps[1].params == ["qux"] + assert deps[1].info == {} + + +def test_loadd_from(dvc): + stage = Stage(dvc) + deps = loadd_from(stage, [{"params": PARAMS}]) + assert len(deps) == 1 + assert isinstance(deps[0], DependencyPARAMS) + assert deps[0].def_path == DependencyPARAMS.DEFAULT_PARAMS_FILE + assert deps[0].params == list(PARAMS.keys()) + assert deps[0].info == PARAMS + + +def test_dumpd_with_info(dvc): + dep = DependencyPARAMS(Stage(dvc), None, PARAMS) + assert dep.dumpd() == { + "path": "params.yaml", + "params": PARAMS, + } + + +def test_dumpd_without_info(dvc): + dep = DependencyPARAMS(Stage(dvc), None, list(PARAMS.keys())) + assert dep.dumpd() == { + "path": "params.yaml", + "params": list(PARAMS.keys()), + } + + +def test_get_info_nonexistent_file(dvc): + dep = DependencyPARAMS(Stage(dvc), None, ["foo"]) + assert dep._get_info() == {} + + +def test_get_info_unsupported_format(tmp_dir, dvc): + tmp_dir.gen("params.yaml", b"\0\1\2\3\4\5\6\7") + dep = DependencyPARAMS(Stage(dvc), None, ["foo"]) + with pytest.raises(BadParamFileError): + dep._get_info() + + +def test_get_info_nested(tmp_dir, dvc): + tmp_dir.gen("params.yaml", yaml.dump({"some": {"path": {"foo": "val"}}})) + dep = DependencyPARAMS(Stage(dvc), None, ["some.path.foo"]) + assert dep._get_info() == {"some.path.foo": "val"} + + +def test_save_info_missing_params(dvc): + dep = DependencyPARAMS(Stage(dvc), None, ["foo"]) + with pytest.raises(MissingParamsError): + dep.save_info()