From fa9fa0b2ae424849af27e598385cdabdb9f6bd12 Mon Sep 17 00:00:00 2001 From: elgehelge Date: Thu, 12 Mar 2020 22:54:31 +0100 Subject: [PATCH 01/11] Work in progress --- dvc/command/run.py | 8 ++++++++ dvc/dependency/__init__.py | 28 +++++++++++++++++----------- dvc/dependency/param.py | 28 ++++++++++++++++++++++++++++ dvc/dependency/repo.py | 4 ++-- dvc/output/base.py | 4 ++++ dvc/stage.py | 8 ++++++-- tests/func/test_run.py | 4 +++- 7 files changed, 68 insertions(+), 16 deletions(-) create mode 100644 dvc/dependency/param.py diff --git a/dvc/command/run.py b/dvc/command/run.py index aa4f7ded5a..4b33a468ff 100644 --- a/dvc/command/run.py +++ b/dvc/command/run.py @@ -40,6 +40,7 @@ def run(self): metrics=self.args.metrics, metrics_no_cache=self.args.metrics_no_cache, deps=self.args.deps, + params=self.args.params, fname=self.args.file, cwd=self.args.cwd, wdir=self.args.wdir, @@ -111,6 +112,13 @@ def add_parser(subparsers, parent_parser): help="Declare output file or directory " "(do not put into DVC cache).", ) + run_parser.add_argument( + "-p", + "--params", + action="append", + default=[], + help="Declare parameter to use as additional dependency.", + ) run_parser.add_argument( "-m", "--metrics", diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index a968f01443..bb0ddff774 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -8,6 +8,7 @@ from dvc.dependency.local import DependencyLOCAL from dvc.dependency.s3 import DependencyS3 from dvc.dependency.ssh import DependencySSH +from dvc.dependency.param import DependencyPARAM from dvc.output.base import OutputBase from dvc.remote import Remote from dvc.scheme import Schemes @@ -42,37 +43,42 @@ SCHEMA = output.SCHEMA.copy() del SCHEMA[OutputBase.PARAM_CACHE] del SCHEMA[OutputBase.PARAM_METRIC] -SCHEMA[DependencyREPO.PARAM_REPO] = DependencyREPO.REPO_SCHEMA +SCHEMA.update(DependencyREPO.REPO_SCHEMA) +SCHEMA.update(DependencyPARAM.PARAM_SCHEMA) -def _get(stage, p, info): - parsed = urlparse(p) +def _get_by_path(stage, path, info): + parsed = urlparse(path) if parsed.scheme == "remote": remote = Remote(stage.repo, name=parsed.netloc) - return DEP_MAP[remote.scheme](stage, p, info, remote=remote) + return DEP_MAP[remote.scheme](stage, path, info, remote=remote) if info and info.get(DependencyREPO.PARAM_REPO): repo = info.pop(DependencyREPO.PARAM_REPO) - return DependencyREPO(repo, stage, p, info) + return DependencyREPO(repo, stage, path, info) for d in DEPS: - if d.supported(p): - return d(stage, p, info) - return DependencyLOCAL(stage, p, info) + if d.supported(path): + return d(stage, path, info) + return DependencyLOCAL(stage, path, info) def loadd_from(stage, d_list): ret = [] for d in d_list: p = d.pop(OutputBase.PARAM_PATH) - ret.append(_get(stage, p, d)) + ret.append(_get_by_path(stage, p, d)) return ret -def loads_from(stage, s_list, erepo=None): +def loads_from(stage, s_list, erepo=None, is_param=False): ret = [] for s in s_list: info = {DependencyREPO.PARAM_REPO: erepo} if erepo else {} - ret.append(_get(stage, s, info)) + if is_param: + dep_obj = DependencyPARAM(stage, s, info) + else: + dep_obj = _get_by_path(stage, s, info) + ret.append(dep_obj) return ret diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py new file mode 100644 index 0000000000..e8af902b3d --- /dev/null +++ b/dvc/dependency/param.py @@ -0,0 +1,28 @@ +from copy import copy + +from dvc.dependency.local import DependencyLOCAL + + +class DependencyPARAM(DependencyLOCAL): + # SCHEMA: + # - path: + # params: + # - : + PARAM_PARAMS = "params" + PARAM_SCHEMA = {DependencyLOCAL.PARAM_PATH: str, PARAM_PARAMS: {str: str}} + + def __init__(self, stage, path_and_param_name, *args, **kwargs): + # TODO: Verify format (no more than one ":", and more I guess) + # TODO: If no file is given, use a default + path, param_name = path_and_param_name.split(':') + self.param_name = param_name + super().__init__(stage, path, *args, **kwargs) + self.def_path = self.def_path + ':' + param_name # TODO: Not sure about this + + @property + def unique_identifier(self): + return self.param_name + + def save(self): + # TODO: Verify exists (parse file and get value) + super().save() \ No newline at end of file diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 53edac176b..48e214e2d6 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -13,11 +13,11 @@ class DependencyREPO(DependencyLOCAL): PARAM_REV = "rev" PARAM_REV_LOCK = "rev_lock" - REPO_SCHEMA = { + REPO_SCHEMA = {PARAM_REPO: { Required(PARAM_URL): str, PARAM_REV: str, PARAM_REV_LOCK: str, - } + }} def __init__(self, def_repo, stage, *args, **kwargs): self.def_repo = def_repo diff --git a/dvc/output/base.py b/dvc/output/base.py index d07010ad6a..405ce71dd2 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -159,6 +159,10 @@ def is_dir_checksum(self): def exists(self): return self.remote.exists(self.path_info) + @property + def unique_identifier(self): + return self.path_info + def changed_checksum(self): return self.checksum != self.remote.get_checksum(self.path_info) diff --git a/dvc/stage.py b/dvc/stage.py index a3e015a6be..55593dc679 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -533,9 +533,13 @@ def create(repo, accompany_outs=False, **kwargs): ) Stage._fill_stage_outputs(stage, **kwargs) - stage.deps = dependency.loads_from( + deps = dependency.loads_from( stage, kwargs.get("deps", []), erepo=kwargs.get("erepo", None) ) + params = dependency.loads_from( + stage, kwargs.get("params", []), erepo=kwargs.get("erepo", None), + is_param=True) + stage.deps = deps + params stage._check_circular_dependency() stage._check_duplicated_arguments() @@ -861,7 +865,7 @@ def _check_duplicated_arguments(self): from dvc.exceptions import ArgumentDuplicationError from collections import Counter - path_counts = Counter(edge.path_info for edge in self.deps + self.outs) + path_counts = Counter(edge.unique_identifier for edge in self.deps + self.outs) for path, occurrence in path_counts.items(): if occurrence > 1: diff --git a/tests/func/test_run.py b/tests/func/test_run.py index 0f97672a67..8d2dd3c7ac 100644 --- a/tests/func/test_run.py +++ b/tests/func/test_run.py @@ -35,6 +35,7 @@ class TestRun(TestDvc): def test(self): cmd = "python {} {} {}".format(self.CODE, self.FOO, "out") deps = [self.FOO, self.CODE] + params = ["foo:some_param"] outs = [os.path.join(self.dvc.root_dir, "out")] outs_no_cache = [] fname = "out.dvc" @@ -45,6 +46,7 @@ def test(self): cmd=cmd, deps=deps, outs=outs, + params=params, outs_no_cache=outs_no_cache, fname=fname, cwd=cwd, @@ -53,7 +55,7 @@ def test(self): self.assertTrue(filecmp.cmp(self.FOO, "out", shallow=False)) self.assertTrue(os.path.isfile(stage.path)) self.assertEqual(stage.cmd, cmd) - self.assertEqual(len(stage.deps), len(deps)) + self.assertEqual(len(stage.deps), len(deps) + len(params)) self.assertEqual(len(stage.outs), len(outs + outs_no_cache)) self.assertEqual(stage.outs[0].fspath, outs[0]) self.assertEqual(stage.outs[0].checksum, file_md5(self.FOO)[0]) From 4c680473f5faeb383816f1cc913c0057601aa82a Mon Sep 17 00:00:00 2001 From: elgehelge Date: Mon, 16 Mar 2020 14:51:37 +0100 Subject: [PATCH 02/11] added file parsing and name validation + adjust schema --- dvc/dependency/param.py | 61 +++++++++++++++++++++++++++++++++-------- tests/basic_env.py | 3 ++ tests/func/test_run.py | 2 +- 3 files changed, 54 insertions(+), 12 deletions(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index e8af902b3d..b859a8d785 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -1,28 +1,67 @@ -from copy import copy +import json +import re from dvc.dependency.local import DependencyLOCAL class DependencyPARAM(DependencyLOCAL): # SCHEMA: - # - path: # params: # - : PARAM_PARAMS = "params" - PARAM_SCHEMA = {DependencyLOCAL.PARAM_PATH: str, PARAM_PARAMS: {str: str}} + # TODO: Combine parameter deps across multiple param deps + PARAM_SCHEMA = {PARAM_PARAMS: {str: str}} + DELIMITER = ':' + DEFAULT_PARAMS_FILE = 'PARAMS.json' + PARAM_NAME_REGEX = re.compile(r'^\w+$') def __init__(self, stage, path_and_param_name, *args, **kwargs): - # TODO: Verify format (no more than one ":", and more I guess) - # TODO: If no file is given, use a default - path, param_name = path_and_param_name.split(':') - self.param_name = param_name + path, _, param_name = path_and_param_name.rpartition(self.DELIMITER) + path = path or self.DEFAULT_PARAMS_FILE + if not self._is_valid_name(param_name): + raise NotImplementedError() # TODO: raise BadParamNameError() ? super().__init__(stage, path, *args, **kwargs) - self.def_path = self.def_path + ':' + param_name # TODO: Not sure about this + self.param_name = param_name + self.param_value = self._parse()[param_name] + + def __str__(self): + path = super().__str__() + return path + ':' + self.param_name @property def unique_identifier(self): return self.param_name - def save(self): - # TODO: Verify exists (parse file and get value) - super().save() \ No newline at end of file + # def save(self): + # # TODO: Do we need to do anything different regarding `save()`? + # super().save() + + def dumpd(self): + return { + self.PARAM_PATH: self.def_path, + self.PARAM_PARAMS: {self.param_name: self.param_value}, + } + + @classmethod + def _is_valid_name(cls, param_name): + return cls.PARAM_NAME_REGEX.match(param_name) + + @property + def exists(self): + file_exists = super().exists + params = self._parse() + param_exists = self.param_name in params + return file_exists and param_exists + + def _parse(self): + try: + return self._params_cache + except AttributeError: + path = self.path_info.fspath + with open(path, 'r') as fp: + try: + self._params_cache = json.load(fp) + except json.JSONDecodeError: + # TODO raise BadParamFileError()? + raise NotImplementedError() + return self._params_cache diff --git a/tests/basic_env.py b/tests/basic_env.py index 119a599182..ddee87cd2e 100644 --- a/tests/basic_env.py +++ b/tests/basic_env.py @@ -38,6 +38,8 @@ class TestDirFixture(object): # in tests, we replace foo with bar, so we need to make sure that when we # modify a file in our tests, its content length changes. BAR_CONTENTS = BAR + "r" + PARAMS = "par.json" + PARAMS_CONTENTS = '{"someparam": "somevalue"}' CODE = "code.py" CODE_CONTENTS = ( "import sys\nimport shutil\n" @@ -87,6 +89,7 @@ def setUp(self): self._pushd(self._root_dir) self.create(self.FOO, self.FOO_CONTENTS) self.create(self.BAR, self.BAR_CONTENTS) + self.create(self.PARAMS, self.PARAMS_CONTENTS) self.create(self.CODE, self.CODE_CONTENTS) os.mkdir(self.DATA_DIR) os.mkdir(self.DATA_SUB_DIR) diff --git a/tests/func/test_run.py b/tests/func/test_run.py index 8d2dd3c7ac..f2a9c8b2c8 100644 --- a/tests/func/test_run.py +++ b/tests/func/test_run.py @@ -35,7 +35,7 @@ class TestRun(TestDvc): def test(self): cmd = "python {} {} {}".format(self.CODE, self.FOO, "out") deps = [self.FOO, self.CODE] - params = ["foo:some_param"] + params = ["par.json:someparam"] outs = [os.path.join(self.dvc.root_dir, "out")] outs_no_cache = [] fname = "out.dvc" From 4e23e7abfea440bfd5f0f5bef1027cc841fe1249 Mon Sep 17 00:00:00 2001 From: elgehelge Date: Mon, 16 Mar 2020 22:12:27 +0100 Subject: [PATCH 03/11] Exceptions on bad input --- dvc/dependency/param.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index b859a8d785..b1a9add3bc 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -2,6 +2,19 @@ import re from dvc.dependency.local import DependencyLOCAL +from dvc.exceptions import DvcException + + +class BadParamNameError(DvcException): + def __init__(self, param_name): + msg = "Parameter name '{}' is not allowed".format(param_name) + super().__init__(msg) + + +class BadParamFileError(DvcException): + def __init__(self, path): + msg = "Parameter file '{}' could not be read".format(path) + super().__init__(msg) class DependencyPARAM(DependencyLOCAL): @@ -19,10 +32,10 @@ def __init__(self, stage, path_and_param_name, *args, **kwargs): path, _, param_name = path_and_param_name.rpartition(self.DELIMITER) path = path or self.DEFAULT_PARAMS_FILE if not self._is_valid_name(param_name): - raise NotImplementedError() # TODO: raise BadParamNameError() ? + raise BadParamNameError(param_name) super().__init__(stage, path, *args, **kwargs) self.param_name = param_name - self.param_value = self._parse()[param_name] + self.param_value = None def __str__(self): path = super().__str__() @@ -32,9 +45,9 @@ def __str__(self): def unique_identifier(self): return self.param_name - # def save(self): - # # TODO: Do we need to do anything different regarding `save()`? - # super().save() + def save(self): + self.param_value = self._parse()[self.param_name] + super().save() # TODO: Not sure if this is needed def dumpd(self): return { @@ -62,6 +75,5 @@ def _parse(self): try: self._params_cache = json.load(fp) except json.JSONDecodeError: - # TODO raise BadParamFileError()? - raise NotImplementedError() + raise BadParamFileError(path) return self._params_cache From 0c54a16a379a2f32958c7b01c91d522cd07915e5 Mon Sep 17 00:00:00 2001 From: elgehelge Date: Mon, 16 Mar 2020 23:49:34 +0100 Subject: [PATCH 04/11] Support multiple parameters --- dvc/dependency/__init__.py | 6 ++--- dvc/dependency/param.py | 48 ++++++++++++++++++++------------------ dvc/output/base.py | 4 ---- dvc/stage.py | 2 +- 4 files changed, 29 insertions(+), 31 deletions(-) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index bb0ddff774..ace6ca8118 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -8,7 +8,7 @@ from dvc.dependency.local import DependencyLOCAL from dvc.dependency.s3 import DependencyS3 from dvc.dependency.ssh import DependencySSH -from dvc.dependency.param import DependencyPARAM +from dvc.dependency.param import DependencyPARAMS from dvc.output.base import OutputBase from dvc.remote import Remote from dvc.scheme import Schemes @@ -44,7 +44,7 @@ del SCHEMA[OutputBase.PARAM_CACHE] del SCHEMA[OutputBase.PARAM_METRIC] SCHEMA.update(DependencyREPO.REPO_SCHEMA) -SCHEMA.update(DependencyPARAM.PARAM_SCHEMA) +SCHEMA.update(DependencyPARAMS.PARAM_SCHEMA) def _get_by_path(stage, path, info): @@ -77,7 +77,7 @@ def loads_from(stage, s_list, erepo=None, is_param=False): for s in s_list: info = {DependencyREPO.PARAM_REPO: erepo} if erepo else {} if is_param: - dep_obj = DependencyPARAM(stage, s, info) + dep_obj = DependencyPARAMS(stage, s, info) else: dep_obj = _get_by_path(stage, s, info) ret.append(dep_obj) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index b1a9add3bc..bf7de7716d 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -7,7 +7,7 @@ class BadParamNameError(DvcException): def __init__(self, param_name): - msg = "Parameter name '{}' is not allowed".format(param_name) + msg = "Parameter name '{}' is not valid".format(param_name) super().__init__(msg) @@ -17,54 +17,56 @@ def __init__(self, path): super().__init__(msg) -class DependencyPARAM(DependencyLOCAL): +class DependencyPARAMS(DependencyLOCAL): # SCHEMA: # params: # - : + # - : PARAM_PARAMS = "params" - # TODO: Combine parameter deps across multiple param deps PARAM_SCHEMA = {PARAM_PARAMS: {str: str}} - DELIMITER = ':' + FILE_DELIMITER = ':' + PARAM_DELIMITER = ',' DEFAULT_PARAMS_FILE = 'PARAMS.json' - PARAM_NAME_REGEX = re.compile(r'^\w+$') - def __init__(self, stage, path_and_param_name, *args, **kwargs): - path, _, param_name = path_and_param_name.rpartition(self.DELIMITER) + REGEX_SUBNAME = r'\w+' + REGEX_NAME = r'{sub}(\.{sub})*'.format(sub=REGEX_SUBNAME) + REGEX_MULTI_PARAMS = r'^{param}(,{param})*$'.format(param=REGEX_NAME) + REGEX_COMPILED = re.compile(REGEX_MULTI_PARAMS) + + def __init__(self, stage, input_str, *args, **kwargs): + path, _, param_names = input_str.rpartition(self.FILE_DELIMITER) path = path or self.DEFAULT_PARAMS_FILE - if not self._is_valid_name(param_name): - raise BadParamNameError(param_name) + if not self._is_valid_name(param_names): + raise BadParamNameError(param_names) super().__init__(stage, path, *args, **kwargs) - self.param_name = param_name - self.param_value = None + self.param_names = sorted(param_names.split(self.PARAM_DELIMITER)) + self.param_values = {} def __str__(self): path = super().__str__() - return path + ':' + self.param_name - - @property - def unique_identifier(self): - return self.param_name + return path + ':' + self.PARAM_DELIMITER.join(self.param_names) def save(self): - self.param_value = self._parse()[self.param_name] - super().save() # TODO: Not sure if this is needed + super().save() + params_in_file = self._parse() + self.param_values = {k: params_in_file[k] for k in self.param_names} def dumpd(self): return { self.PARAM_PATH: self.def_path, - self.PARAM_PARAMS: {self.param_name: self.param_value}, + self.PARAM_PARAMS: self.param_values, } @classmethod def _is_valid_name(cls, param_name): - return cls.PARAM_NAME_REGEX.match(param_name) + return cls.REGEX_COMPILED.match(param_name) @property def exists(self): file_exists = super().exists - params = self._parse() - param_exists = self.param_name in params - return file_exists and param_exists + params_in_file = self._parse() + params_exists = all([p in params_in_file for p in self.param_names]) + return file_exists and params_exists def _parse(self): try: diff --git a/dvc/output/base.py b/dvc/output/base.py index 405ce71dd2..d07010ad6a 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -159,10 +159,6 @@ def is_dir_checksum(self): def exists(self): return self.remote.exists(self.path_info) - @property - def unique_identifier(self): - return self.path_info - def changed_checksum(self): return self.checksum != self.remote.get_checksum(self.path_info) diff --git a/dvc/stage.py b/dvc/stage.py index 55593dc679..e95ada76db 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -865,7 +865,7 @@ def _check_duplicated_arguments(self): from dvc.exceptions import ArgumentDuplicationError from collections import Counter - path_counts = Counter(edge.unique_identifier for edge in self.deps + self.outs) + path_counts = Counter(edge.path_info for edge in self.deps + self.outs) for path, occurrence in path_counts.items(): if occurrence > 1: From 01ec8e07334e99184504f2db4b2e0599ec70203d Mon Sep 17 00:00:00 2001 From: elgehelge Date: Thu, 19 Mar 2020 23:24:49 +0100 Subject: [PATCH 05/11] Support multi 's in Having any troubles? Hit us up at https://dvc.org/support, we are always happy to help! --- dvc/dependency/__init__.py | 11 +++--- dvc/dependency/param.py | 52 +++++++++++++++++++++------- dvc/stage.py | 4 +-- tests/basic_env.py | 5 ++- tests/func/test_run.py | 4 +-- tests/unit/dependency/test_params.py | 17 +++++++++ 6 files changed, 69 insertions(+), 24 deletions(-) create mode 100644 tests/unit/dependency/test_params.py diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index ace6ca8118..9d2e93e00a 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -72,13 +72,14 @@ def loadd_from(stage, d_list): return ret -def loads_from(stage, s_list, erepo=None, is_param=False): +def loads_from(stage, s_list, erepo=None): ret = [] for s in s_list: info = {DependencyREPO.PARAM_REPO: erepo} if erepo else {} - if is_param: - dep_obj = DependencyPARAMS(stage, s, info) - else: - dep_obj = _get_by_path(stage, s, info) + dep_obj = _get_by_path(stage, s, info) ret.append(dep_obj) return ret + + +def loads_params(stage, s_list): # TODO: Make support for `eropo=` as well ? + return DependencyPARAMS.from_list(stage, s_list) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index bf7de7716d..33d59fecb1 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -1,5 +1,6 @@ import json import re +from itertools import groupby from dvc.dependency.local import DependencyLOCAL from dvc.exceptions import DvcException @@ -26,7 +27,7 @@ class DependencyPARAMS(DependencyLOCAL): PARAM_SCHEMA = {PARAM_PARAMS: {str: str}} FILE_DELIMITER = ':' PARAM_DELIMITER = ',' - DEFAULT_PARAMS_FILE = 'PARAMS.json' + DEFAULT_PARAMS_FILE = 'params.json' REGEX_SUBNAME = r'\w+' REGEX_NAME = r'{sub}(\.{sub})*'.format(sub=REGEX_SUBNAME) @@ -34,21 +35,50 @@ class DependencyPARAMS(DependencyLOCAL): REGEX_COMPILED = re.compile(REGEX_MULTI_PARAMS) def __init__(self, stage, input_str, *args, **kwargs): - path, _, param_names = input_str.rpartition(self.FILE_DELIMITER) - path = path or self.DEFAULT_PARAMS_FILE - if not self._is_valid_name(param_names): - raise BadParamNameError(param_names) + path, param_names = self._parse_and_validate_input(input_str) super().__init__(stage, path, *args, **kwargs) self.param_names = sorted(param_names.split(self.PARAM_DELIMITER)) self.param_values = {} def __str__(self): path = super().__str__() - return path + ':' + self.PARAM_DELIMITER.join(self.param_names) + return self._reverse_parse_input(path, self.param_names) + + @classmethod + def from_list(cls, stage, s_list): + # Creates an object for each unique file that is referenced in the list + ret = [] + pathname_tuples = [cls._parse_and_validate_input(s) for s in s_list] + grouped_by_path = groupby(sorted(pathname_tuples), key=lambda x: x[0]) + for path, group in grouped_by_path: + param_names = [g[1] for g in group] + regrouped_input = cls._reverse_parse_input(path, param_names) + ret.append(DependencyPARAMS(stage, regrouped_input)) + return ret + + @classmethod + def _parse_and_validate_input(cls, input_str): + path, _, param_names = input_str.rpartition(cls.FILE_DELIMITER) + cls._validate_input(param_names) + path = path or cls.DEFAULT_PARAMS_FILE + return path, param_names + + @classmethod + def _reverse_parse_input(cls, path, param_names): + return '{path}{delimiter}{params}'.format( + path=path, + delimiter=cls.FILE_DELIMITER, + params=cls.PARAM_DELIMITER.join(param_names), + ) + + @classmethod + def _validate_input(cls, param_names): + if not cls.REGEX_COMPILED.match(param_names): + raise BadParamNameError(param_names) def save(self): super().save() - params_in_file = self._parse() + params_in_file = self._parse_file() self.param_values = {k: params_in_file[k] for k in self.param_names} def dumpd(self): @@ -57,18 +87,14 @@ def dumpd(self): self.PARAM_PARAMS: self.param_values, } - @classmethod - def _is_valid_name(cls, param_name): - return cls.REGEX_COMPILED.match(param_name) - @property def exists(self): file_exists = super().exists - params_in_file = self._parse() + params_in_file = self._parse_file() params_exists = all([p in params_in_file for p in self.param_names]) return file_exists and params_exists - def _parse(self): + def _parse_file(self): try: return self._params_cache except AttributeError: diff --git a/dvc/stage.py b/dvc/stage.py index e95ada76db..9d8ac9225b 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -536,9 +536,7 @@ def create(repo, accompany_outs=False, **kwargs): deps = dependency.loads_from( stage, kwargs.get("deps", []), erepo=kwargs.get("erepo", None) ) - params = dependency.loads_from( - stage, kwargs.get("params", []), erepo=kwargs.get("erepo", None), - is_param=True) + params = dependency.loads_params(stage, kwargs.get("params", [])) stage.deps = deps + params stage._check_circular_dependency() diff --git a/tests/basic_env.py b/tests/basic_env.py index ddee87cd2e..49353070a1 100644 --- a/tests/basic_env.py +++ b/tests/basic_env.py @@ -38,8 +38,10 @@ class TestDirFixture(object): # in tests, we replace foo with bar, so we need to make sure that when we # modify a file in our tests, its content length changes. BAR_CONTENTS = BAR + "r" + PARAMSDEFAULT = "params.json" + PARAMSDEFAULT_CONTENTS = '{"p_one": "1", "p_two": "1"}' PARAMS = "par.json" - PARAMS_CONTENTS = '{"someparam": "somevalue"}' + PARAMS_CONTENTS = '{"p_three": "3"}' CODE = "code.py" CODE_CONTENTS = ( "import sys\nimport shutil\n" @@ -89,6 +91,7 @@ def setUp(self): self._pushd(self._root_dir) self.create(self.FOO, self.FOO_CONTENTS) self.create(self.BAR, self.BAR_CONTENTS) + self.create(self.PARAMSDEFAULT, self.PARAMSDEFAULT_CONTENTS) self.create(self.PARAMS, self.PARAMS_CONTENTS) self.create(self.CODE, self.CODE_CONTENTS) os.mkdir(self.DATA_DIR) diff --git a/tests/func/test_run.py b/tests/func/test_run.py index f2a9c8b2c8..8e1bc195ce 100644 --- a/tests/func/test_run.py +++ b/tests/func/test_run.py @@ -35,7 +35,7 @@ class TestRun(TestDvc): def test(self): cmd = "python {} {} {}".format(self.CODE, self.FOO, "out") deps = [self.FOO, self.CODE] - params = ["par.json:someparam"] + params = ["p_one", "p_two", "par.json:p_three"] outs = [os.path.join(self.dvc.root_dir, "out")] outs_no_cache = [] fname = "out.dvc" @@ -55,7 +55,7 @@ def test(self): self.assertTrue(filecmp.cmp(self.FOO, "out", shallow=False)) self.assertTrue(os.path.isfile(stage.path)) self.assertEqual(stage.cmd, cmd) - self.assertEqual(len(stage.deps), len(deps) + len(params)) + self.assertEqual(len(stage.deps), len(deps) + 2) self.assertEqual(len(stage.outs), len(outs + outs_no_cache)) self.assertEqual(stage.outs[0].fspath, outs[0]) self.assertEqual(stage.outs[0].checksum, file_md5(self.FOO)[0]) diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py new file mode 100644 index 0000000000..75da90208f --- /dev/null +++ b/tests/unit/dependency/test_params.py @@ -0,0 +1,17 @@ +import mock + +from dvc.dependency import DependencyPARAMS +from dvc.stage import Stage +from tests.basic_env import TestDvc + + +class TestDependencyPARAM(TestDvc): + def test_from_list(self): + stage = Stage(self.dvc) + deps = DependencyPARAMS.from_list(stage, ['foo', 'bar,baz', + 'a_file:qux']) + assert len(deps) == 2 + assert deps[0].def_path == "a_file" + assert deps[0].param_names == ["qux"] + assert deps[1].def_path == DependencyPARAMS.DEFAULT_PARAMS_FILE + assert deps[1].param_names == ["bar", "baz", "foo"] From 3f9db9c33df092b4f79ce9ea9dec71e420cd6cab Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Fri, 20 Mar 2020 00:56:31 +0200 Subject: [PATCH 06/11] fix formatting --- dvc/dependency/param.py | 16 ++++++++-------- dvc/dependency/repo.py | 12 +++++++----- tests/unit/dependency/test_params.py | 7 +++---- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 33d59fecb1..f44139579c 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -25,13 +25,13 @@ class DependencyPARAMS(DependencyLOCAL): # - : PARAM_PARAMS = "params" PARAM_SCHEMA = {PARAM_PARAMS: {str: str}} - FILE_DELIMITER = ':' - PARAM_DELIMITER = ',' - DEFAULT_PARAMS_FILE = 'params.json' + FILE_DELIMITER = ":" + PARAM_DELIMITER = "," + DEFAULT_PARAMS_FILE = "params.json" - REGEX_SUBNAME = r'\w+' - REGEX_NAME = r'{sub}(\.{sub})*'.format(sub=REGEX_SUBNAME) - REGEX_MULTI_PARAMS = r'^{param}(,{param})*$'.format(param=REGEX_NAME) + REGEX_SUBNAME = r"\w+" + REGEX_NAME = r"{sub}(\.{sub})*".format(sub=REGEX_SUBNAME) + REGEX_MULTI_PARAMS = r"^{param}(,{param})*$".format(param=REGEX_NAME) REGEX_COMPILED = re.compile(REGEX_MULTI_PARAMS) def __init__(self, stage, input_str, *args, **kwargs): @@ -65,7 +65,7 @@ def _parse_and_validate_input(cls, input_str): @classmethod def _reverse_parse_input(cls, path, param_names): - return '{path}{delimiter}{params}'.format( + return "{path}{delimiter}{params}".format( path=path, delimiter=cls.FILE_DELIMITER, params=cls.PARAM_DELIMITER.join(param_names), @@ -99,7 +99,7 @@ def _parse_file(self): return self._params_cache except AttributeError: path = self.path_info.fspath - with open(path, 'r') as fp: + with open(path, "r") as fp: try: self._params_cache = json.load(fp) except json.JSONDecodeError: diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 48e214e2d6..56292c4272 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -13,11 +13,13 @@ class DependencyREPO(DependencyLOCAL): PARAM_REV = "rev" PARAM_REV_LOCK = "rev_lock" - REPO_SCHEMA = {PARAM_REPO: { - Required(PARAM_URL): str, - PARAM_REV: str, - PARAM_REV_LOCK: str, - }} + REPO_SCHEMA = { + PARAM_REPO: { + Required(PARAM_URL): str, + PARAM_REV: str, + PARAM_REV_LOCK: str, + } + } def __init__(self, def_repo, stage, *args, **kwargs): self.def_repo = def_repo diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index 75da90208f..676e6c856c 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -1,5 +1,3 @@ -import mock - from dvc.dependency import DependencyPARAMS from dvc.stage import Stage from tests.basic_env import TestDvc @@ -8,8 +6,9 @@ class TestDependencyPARAM(TestDvc): def test_from_list(self): stage = Stage(self.dvc) - deps = DependencyPARAMS.from_list(stage, ['foo', 'bar,baz', - 'a_file:qux']) + deps = DependencyPARAMS.from_list( + stage, ["foo", "bar,baz", "a_file:qux"] + ) assert len(deps) == 2 assert deps[0].def_path == "a_file" assert deps[0].param_names == ["qux"] From 3eddcf573f8fc21a5892e8b99032a164fb0d63c3 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Fri, 20 Mar 2020 16:50:18 +0200 Subject: [PATCH 07/11] dep: param: fix schema Signed-off-by: Ruslan Kuprieiev --- dvc/dependency/param.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index f44139579c..fcd70d63e0 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -1,6 +1,7 @@ import json import re from itertools import groupby +from voluptuous import Any from dvc.dependency.local import DependencyLOCAL from dvc.exceptions import DvcException @@ -19,12 +20,8 @@ def __init__(self, path): class DependencyPARAMS(DependencyLOCAL): - # SCHEMA: - # params: - # - : - # - : PARAM_PARAMS = "params" - PARAM_SCHEMA = {PARAM_PARAMS: {str: str}} + PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, None)} FILE_DELIMITER = ":" PARAM_DELIMITER = "," DEFAULT_PARAMS_FILE = "params.json" From 4fbbaac5b2788110c936190f7807bb52a76dd071 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Fri, 20 Mar 2020 16:50:47 +0200 Subject: [PATCH 08/11] run: check if --params is specified --- dvc/command/run.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dvc/command/run.py b/dvc/command/run.py index 4b33a468ff..fdac98bcaa 100644 --- a/dvc/command/run.py +++ b/dvc/command/run.py @@ -22,13 +22,14 @@ def run(self): self.args.metrics_no_cache, self.args.outs_persist, self.args.outs_persist_no_cache, + self.args.params, self.args.command, ] ): # pragma: no cover logger.error( "too few arguments. Specify at least one: `-d`, `-o`, `-O`, " - "`-m`, `-M`, `--outs-persist`, `--outs-persist-no-cache`, " - "`command`." + "`-m`, `-M`, `-p`, `--outs-persist`, " + "`--outs-persist-no-cache`, `command`." ) return 1 From 12a6e9ab5d2eeee76d70383e46d2b7ed4d4aae40 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Sat, 21 Mar 2020 19:32:50 +0200 Subject: [PATCH 09/11] dvc: rework param handling --- dvc/dependency/__init__.py | 54 ++++++--- dvc/dependency/param.py | 160 +++++++++++++-------------- dvc/output/base.py | 7 +- tests/basic_env.py | 6 - tests/func/test_run.py | 4 +- tests/unit/command/test_run.py | 7 ++ tests/unit/dependency/test_params.py | 87 ++++++++++++--- 7 files changed, 203 insertions(+), 122 deletions(-) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 9d2e93e00a..f6346cdeec 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -1,4 +1,5 @@ from urllib.parse import urlparse +from collections import defaultdict import dvc.output as output from dvc.dependency.gs import DependencyGS @@ -47,28 +48,31 @@ SCHEMA.update(DependencyPARAMS.PARAM_SCHEMA) -def _get_by_path(stage, path, info): - parsed = urlparse(path) - - if parsed.scheme == "remote": +def _get(stage, p, info): + parsed = urlparse(p) if p else None + if parsed and parsed.scheme == "remote": remote = Remote(stage.repo, name=parsed.netloc) - return DEP_MAP[remote.scheme](stage, path, info, remote=remote) + return DEP_MAP[remote.scheme](stage, p, info, remote=remote) if info and info.get(DependencyREPO.PARAM_REPO): repo = info.pop(DependencyREPO.PARAM_REPO) - return DependencyREPO(repo, stage, path, info) + return DependencyREPO(repo, stage, p, info) + + if info and info.get(DependencyPARAMS.PARAM_PARAMS): + params = info.pop(DependencyPARAMS.PARAM_PARAMS) + return DependencyPARAMS(stage, p, params) for d in DEPS: - if d.supported(path): - return d(stage, path, info) - return DependencyLOCAL(stage, path, info) + if d.supported(p): + return d(stage, p, info) + return DependencyLOCAL(stage, p, info) def loadd_from(stage, d_list): ret = [] for d in d_list: - p = d.pop(OutputBase.PARAM_PATH) - ret.append(_get_by_path(stage, p, d)) + p = d.pop(OutputBase.PARAM_PATH, None) + ret.append(_get(stage, p, d)) return ret @@ -76,10 +80,30 @@ def loads_from(stage, s_list, erepo=None): ret = [] for s in s_list: info = {DependencyREPO.PARAM_REPO: erepo} if erepo else {} - dep_obj = _get_by_path(stage, s, info) - ret.append(dep_obj) + ret.append(_get(stage, s, info)) return ret -def loads_params(stage, s_list): # TODO: Make support for `eropo=` as well ? - return DependencyPARAMS.from_list(stage, s_list) +def _parse_params(path_params): + path, _, params_str = path_params.rpartition(":") + params = params_str.split(",") + return path, params + + +def loads_params(stage, s_list): + # Creates an object for each unique file that is referenced in the list + params_by_path = defaultdict(list) + for s in s_list: + path, params = _parse_params(s) + params_by_path[path].extend(params) + + d_list = [] + for path, params in params_by_path.items(): + d_list.append( + { + OutputBase.PARAM_PATH: path, + DependencyPARAMS.PARAM_PARAMS: params, + } + ) + + return loadd_from(stage, d_list) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index fcd70d63e0..94ba54b48e 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -1,104 +1,102 @@ -import json -import re -from itertools import groupby +import os +import yaml +from collections import defaultdict + from voluptuous import Any +from funcy import select_keys +from dvc.compat import fspath_py35 from dvc.dependency.local import DependencyLOCAL from dvc.exceptions import DvcException -class BadParamNameError(DvcException): - def __init__(self, param_name): - msg = "Parameter name '{}' is not valid".format(param_name) - super().__init__(msg) +class MissingParamsError(DvcException): + pass class BadParamFileError(DvcException): - def __init__(self, path): - msg = "Parameter file '{}' could not be read".format(path) - super().__init__(msg) + pass class DependencyPARAMS(DependencyLOCAL): PARAM_PARAMS = "params" - PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, None)} - FILE_DELIMITER = ":" - PARAM_DELIMITER = "," - DEFAULT_PARAMS_FILE = "params.json" - - REGEX_SUBNAME = r"\w+" - REGEX_NAME = r"{sub}(\.{sub})*".format(sub=REGEX_SUBNAME) - REGEX_MULTI_PARAMS = r"^{param}(,{param})*$".format(param=REGEX_NAME) - REGEX_COMPILED = re.compile(REGEX_MULTI_PARAMS) - - def __init__(self, stage, input_str, *args, **kwargs): - path, param_names = self._parse_and_validate_input(input_str) - super().__init__(stage, path, *args, **kwargs) - self.param_names = sorted(param_names.split(self.PARAM_DELIMITER)) - self.param_values = {} - - def __str__(self): - path = super().__str__() - return self._reverse_parse_input(path, self.param_names) - - @classmethod - def from_list(cls, stage, s_list): - # Creates an object for each unique file that is referenced in the list - ret = [] - pathname_tuples = [cls._parse_and_validate_input(s) for s in s_list] - grouped_by_path = groupby(sorted(pathname_tuples), key=lambda x: x[0]) - for path, group in grouped_by_path: - param_names = [g[1] for g in group] - regrouped_input = cls._reverse_parse_input(path, param_names) - ret.append(DependencyPARAMS(stage, regrouped_input)) - return ret - - @classmethod - def _parse_and_validate_input(cls, input_str): - path, _, param_names = input_str.rpartition(cls.FILE_DELIMITER) - cls._validate_input(param_names) - path = path or cls.DEFAULT_PARAMS_FILE - return path, param_names - - @classmethod - def _reverse_parse_input(cls, path, param_names): - return "{path}{delimiter}{params}".format( - path=path, - delimiter=cls.FILE_DELIMITER, - params=cls.PARAM_DELIMITER.join(param_names), + PARAM_SCHEMA = {PARAM_PARAMS: Any(dict, list, None)} + DEFAULT_PARAMS_FILE = "params.yaml" + + def __init__(self, stage, path, params): + info = {} + self.params = [] + if params: + if isinstance(params, list): + self.params = params + else: + assert isinstance(params, dict) + self.params = list(params.keys()) + info = params + + super().__init__( + stage, + path + or os.path.join(stage.repo.root_dir, self.DEFAULT_PARAMS_FILE), + info=info, ) - @classmethod - def _validate_input(cls, param_names): - if not cls.REGEX_COMPILED.match(param_names): - raise BadParamNameError(param_names) - def save(self): super().save() - params_in_file = self._parse_file() - self.param_values = {k: params_in_file[k] for k in self.param_names} + self.info = self.save_info() + + def status(self): + status = super().status() + + if status[str(self)] == "deleted": + return status + + status = defaultdict(dict) + info = self._get_info() + for param in self.params: + if param not in info.keys(): + st = "deleted" + elif param not in self.info: + st = "new" + elif info[param] != self.info[param]: + st = "modified" + else: + assert info[param] == self.info[param] + continue + + status[str(self)][param] = st + + return status def dumpd(self): return { self.PARAM_PATH: self.def_path, - self.PARAM_PARAMS: self.param_values, + self.PARAM_PARAMS: self.info or self.params, } - @property - def exists(self): - file_exists = super().exists - params_in_file = self._parse_file() - params_exists = all([p in params_in_file for p in self.param_names]) - return file_exists and params_exists - - def _parse_file(self): - try: - return self._params_cache - except AttributeError: - path = self.path_info.fspath - with open(path, "r") as fp: - try: - self._params_cache = json.load(fp) - except json.JSONDecodeError: - raise BadParamFileError(path) - return self._params_cache + def _get_info(self): + if not self.exists: + return {} + + with open(fspath_py35(self.path_info), "r") as fobj: + try: + config = yaml.safe_load(fobj) + except yaml.YAMLError as exc: + raise BadParamFileError( + "Unable to read parameters from '{}'".format(self) + ) from exc + + return select_keys(lambda key: key in self.params, config) + + def save_info(self): + info = self._get_info() + + missing_params = set(self.params) - set(info.keys()) + if missing_params: + raise MissingParamsError( + "Parameters '{}' are missing from '{}'.".format( + missing_params, self, + ) + ) + + return info diff --git a/dvc/output/base.py b/dvc/output/base.py index d07010ad6a..f6f7856de5 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -159,6 +159,9 @@ def is_dir_checksum(self): def exists(self): return self.remote.exists(self.path_info) + def save_info(self): + return self.remote.save_info(self.path_info) + def changed_checksum(self): return self.checksum != self.remote.get_checksum(self.path_info) @@ -215,7 +218,7 @@ def save(self): self.repo.scm.ignore(self.fspath) if not self.use_cache: - self.info = self.remote.save_info(self.path_info) + self.info = self.save_info() if self.metric: self.verify_metric() if not self.IS_DEPENDENCY: @@ -234,7 +237,7 @@ def save(self): ) return - self.info = self.remote.save_info(self.path_info) + self.info = self.save_info() def commit(self): if self.use_cache: diff --git a/tests/basic_env.py b/tests/basic_env.py index 49353070a1..119a599182 100644 --- a/tests/basic_env.py +++ b/tests/basic_env.py @@ -38,10 +38,6 @@ class TestDirFixture(object): # in tests, we replace foo with bar, so we need to make sure that when we # modify a file in our tests, its content length changes. BAR_CONTENTS = BAR + "r" - PARAMSDEFAULT = "params.json" - PARAMSDEFAULT_CONTENTS = '{"p_one": "1", "p_two": "1"}' - PARAMS = "par.json" - PARAMS_CONTENTS = '{"p_three": "3"}' CODE = "code.py" CODE_CONTENTS = ( "import sys\nimport shutil\n" @@ -91,8 +87,6 @@ def setUp(self): self._pushd(self._root_dir) self.create(self.FOO, self.FOO_CONTENTS) self.create(self.BAR, self.BAR_CONTENTS) - self.create(self.PARAMSDEFAULT, self.PARAMSDEFAULT_CONTENTS) - self.create(self.PARAMS, self.PARAMS_CONTENTS) self.create(self.CODE, self.CODE_CONTENTS) os.mkdir(self.DATA_DIR) os.mkdir(self.DATA_SUB_DIR) diff --git a/tests/func/test_run.py b/tests/func/test_run.py index 8e1bc195ce..0f97672a67 100644 --- a/tests/func/test_run.py +++ b/tests/func/test_run.py @@ -35,7 +35,6 @@ class TestRun(TestDvc): def test(self): cmd = "python {} {} {}".format(self.CODE, self.FOO, "out") deps = [self.FOO, self.CODE] - params = ["p_one", "p_two", "par.json:p_three"] outs = [os.path.join(self.dvc.root_dir, "out")] outs_no_cache = [] fname = "out.dvc" @@ -46,7 +45,6 @@ def test(self): cmd=cmd, deps=deps, outs=outs, - params=params, outs_no_cache=outs_no_cache, fname=fname, cwd=cwd, @@ -55,7 +53,7 @@ def test(self): self.assertTrue(filecmp.cmp(self.FOO, "out", shallow=False)) self.assertTrue(os.path.isfile(stage.path)) self.assertEqual(stage.cmd, cmd) - self.assertEqual(len(stage.deps), len(deps) + 2) + self.assertEqual(len(stage.deps), len(deps)) self.assertEqual(len(stage.outs), len(outs + outs_no_cache)) self.assertEqual(stage.outs[0].fspath, outs[0]) self.assertEqual(stage.outs[0].checksum, file_md5(self.FOO)[0]) diff --git a/tests/unit/command/test_run.py b/tests/unit/command/test_run.py index f11b8e7a18..2b6dcfe659 100644 --- a/tests/unit/command/test_run.py +++ b/tests/unit/command/test_run.py @@ -33,6 +33,10 @@ def test_run(mocker, dvc): "--outs-persist-no-cache", "outs-persist-no-cache", "--always-changed", + "--params", + "file:param1,param2", + "--params", + "param3", "command", ] ) @@ -51,6 +55,7 @@ def test_run(mocker, dvc): metrics_no_cache=["metrics-no-cache"], outs_persist=["outs-persist"], outs_persist_no_cache=["outs-persist-no-cache"], + params=["file:param1,param2", "param3"], fname="file", cwd="cwd", wdir="wdir", @@ -77,6 +82,7 @@ def test_run_args_from_cli(mocker, dvc): metrics_no_cache=[], outs_persist=[], outs_persist_no_cache=[], + params=[], fname=None, cwd=None, wdir=None, @@ -103,6 +109,7 @@ def test_run_args_with_spaces(mocker, dvc): metrics_no_cache=[], outs_persist=[], outs_persist_no_cache=[], + params=[], fname=None, cwd=None, wdir=None, diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index 676e6c856c..15fcfae9fc 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -1,16 +1,73 @@ -from dvc.dependency import DependencyPARAMS +import pytest + +from dvc.dependency import DependencyPARAMS, loads_params, loadd_from +from dvc.dependency.param import BadParamFileError, MissingParamsError from dvc.stage import Stage -from tests.basic_env import TestDvc - - -class TestDependencyPARAM(TestDvc): - def test_from_list(self): - stage = Stage(self.dvc) - deps = DependencyPARAMS.from_list( - stage, ["foo", "bar,baz", "a_file:qux"] - ) - assert len(deps) == 2 - assert deps[0].def_path == "a_file" - assert deps[0].param_names == ["qux"] - assert deps[1].def_path == DependencyPARAMS.DEFAULT_PARAMS_FILE - assert deps[1].param_names == ["bar", "baz", "foo"] + + +PARAMS = { + "foo": 1, + "bar": 53.135, + "baz": "str", + "qux": None, +} + + +def test_loads_params(dvc): + stage = Stage(dvc) + deps = loads_params(stage, ["foo", "bar,baz", "a_file:qux"]) + assert len(deps) == 2 + + assert isinstance(deps[0], DependencyPARAMS) + assert deps[0].def_path == DependencyPARAMS.DEFAULT_PARAMS_FILE + assert deps[0].params == ["foo", "bar", "baz"] + assert deps[0].info == {} + + assert isinstance(deps[1], DependencyPARAMS) + assert deps[1].def_path == "a_file" + assert deps[1].params == ["qux"] + assert deps[1].info == {} + + +def test_loadd_from(dvc): + stage = Stage(dvc) + deps = loadd_from(stage, [{"params": PARAMS}]) + assert len(deps) == 1 + assert isinstance(deps[0], DependencyPARAMS) + assert deps[0].def_path == DependencyPARAMS.DEFAULT_PARAMS_FILE + assert deps[0].params == list(PARAMS.keys()) + assert deps[0].info == PARAMS + + +def test_dumpd_with_info(dvc): + dep = DependencyPARAMS(Stage(dvc), None, PARAMS) + assert dep.dumpd() == { + "path": "params.yaml", + "params": PARAMS, + } + + +def test_dumpd_without_info(dvc): + dep = DependencyPARAMS(Stage(dvc), None, list(PARAMS.keys())) + assert dep.dumpd() == { + "path": "params.yaml", + "params": list(PARAMS.keys()), + } + + +def test_get_info_nonexistent_file(dvc): + dep = DependencyPARAMS(Stage(dvc), None, ["foo"]) + assert dep._get_info() == {} + + +def test_get_info_unsupported_format(tmp_dir, dvc): + tmp_dir.gen("params.yaml", b"\0\1\2\3\4\5\6\7") + dep = DependencyPARAMS(Stage(dvc), None, ["foo"]) + with pytest.raises(BadParamFileError): + dep._get_info() + + +def test_save_info_missing_params(dvc): + dep = DependencyPARAMS(Stage(dvc), None, ["foo"]) + with pytest.raises(MissingParamsError): + dep.save_info() From be21ad7838820d954c61dc0481afa52c614ce844 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Tue, 24 Mar 2020 14:47:23 +0200 Subject: [PATCH 10/11] params: fix exc formatting --- dvc/dependency/param.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 94ba54b48e..6d00f45314 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -95,7 +95,7 @@ def save_info(self): if missing_params: raise MissingParamsError( "Parameters '{}' are missing from '{}'.".format( - missing_params, self, + ", ".join(missing_params), self, ) ) From 4ec026791ca4c61cf0df2466b7ef65ecfd17f60b Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Tue, 24 Mar 2020 15:10:19 +0200 Subject: [PATCH 11/11] params: support basic nested configs --- dvc/dependency/param.py | 3 +++ tests/unit/dependency/test_params.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/dvc/dependency/param.py b/dvc/dependency/param.py index 6d00f45314..c5cd4a2b34 100644 --- a/dvc/dependency/param.py +++ b/dvc/dependency/param.py @@ -4,6 +4,7 @@ from voluptuous import Any from funcy import select_keys +from flatten_json import flatten from dvc.compat import fspath_py35 from dvc.dependency.local import DependencyLOCAL @@ -86,6 +87,8 @@ def _get_info(self): "Unable to read parameters from '{}'".format(self) ) from exc + config = flatten(config, ".") + return select_keys(lambda key: key in self.params, config) def save_info(self): diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index 15fcfae9fc..72d82d099a 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -1,3 +1,5 @@ +import yaml + import pytest from dvc.dependency import DependencyPARAMS, loads_params, loadd_from @@ -67,6 +69,12 @@ def test_get_info_unsupported_format(tmp_dir, dvc): dep._get_info() +def test_get_info_nested(tmp_dir, dvc): + tmp_dir.gen("params.yaml", yaml.dump({"some": {"path": {"foo": "val"}}})) + dep = DependencyPARAMS(Stage(dvc), None, ["some.path.foo"]) + assert dep._get_info() == {"some.path.foo": "val"} + + def test_save_info_missing_params(dvc): dep = DependencyPARAMS(Stage(dvc), None, ["foo"]) with pytest.raises(MissingParamsError):