diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index 7b11d34734..6113906591 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -4,16 +4,21 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from itertools import starmap -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from funcy import first, join from dvc.dependency.param import ParamsDependency from dvc.path_info import PathInfo -from dvc.utils.serialize import dumps_yaml from .context import Context -from .interpolate import resolve +from .interpolate import ( + _get_matches, + _is_exact_string, + _is_interpolated_string, + _resolve_str, + resolve, +) if TYPE_CHECKING: from dvc.repo import Repo @@ -28,8 +33,10 @@ PARAMS_KWD = "params" FOREACH_KWD = "foreach" IN_KWD = "in" +SET_KWD = "set" DEFAULT_SENTINEL = object() +SeqOrMap = Union[Sequence, Mapping] class DataResolver: @@ -56,6 +63,7 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict): def _resolve_entry(self, name: str, definition): context = Context.clone(self.global_ctx) if FOREACH_KWD in definition: + self.set_context_from(context, definition.get(SET_KWD, {})) assert IN_KWD in definition return self._foreach( context, name, definition[FOREACH_KWD], definition[IN_KWD] @@ -65,11 +73,12 @@ def _resolve_entry(self, name: str, definition): def resolve(self): stages = self.data.get(STAGES_KWD, {}) data = join(starmap(self._resolve_entry, stages.items())) - logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data)) + logger.trace("Resolved dvc.yaml:\n%s", data) return {STAGES_KWD: data} def _resolve_stage(self, context: Context, name: str, definition) -> dict: definition = deepcopy(definition) + self.set_context_from(context, definition.pop(SET_KWD, {})) wdir = self._resolve_wdir(context, definition.get(WDIR_KWD)) if self.wdir != wdir: logger.debug( @@ -135,10 +144,56 @@ def each_iter(value, key=DEFAULT_SENTINEL): return self._resolve_stage(c, f"{name}-{suffix}", in_data) iterable = resolve(foreach_data, context) + + assert isinstance(iterable, (Sequence, Mapping)) and not isinstance( + iterable, str + ), f"got type of {type(iterable)}" if isinstance(iterable, Sequence): gen = (each_iter(v) for v in iterable) - elif isinstance(iterable, Mapping): - gen = (each_iter(v, k) for k, v in iterable.items()) else: - raise Exception(f"got type of {type(iterable)}") + gen = (each_iter(v, k) for k, v in iterable.items()) return join(gen) + + @classmethod + def set_context_from(cls, context: Context, to_set): + for key, value in to_set.items(): + if key in context: + raise ValueError(f"Cannot set '{key}', key already exists") + if isinstance(value, str): + cls._check_joined_with_interpolation(key, value) + value = _resolve_str(value, context, unwrap=False) + elif isinstance(value, (Sequence, Mapping)): + cls._check_nested_collection(key, value) + cls._check_interpolation_collection(key, value) + context[key] = value + + @staticmethod + def _check_nested_collection(key: str, value: SeqOrMap): + values = value.values() if isinstance(value, Mapping) else value + has_nested = any( + not isinstance(item, str) and isinstance(item, (Mapping, Sequence)) + for item in values + ) + if has_nested: + raise ValueError(f"Cannot set '{key}', has nested dict/list") + + @staticmethod + def _check_interpolation_collection(key: str, value: SeqOrMap): + values = value.values() if isinstance(value, Mapping) else value + interpolated = any(_is_interpolated_string(item) for item in values) + if interpolated: + raise ValueError( + f"Cannot set '{key}', " + "having interpolation inside " + f"'{type(value).__name__}' is not supported." + ) + + @staticmethod + def _check_joined_with_interpolation(key: str, value: str): + matches = _get_matches(value) + if matches and not _is_exact_string(value, matches): + raise ValueError( + f"Cannot set '{key}', " + "joining string with interpolated string" + "is not supported" + ) diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index 5817c46b28..fd192fcee3 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -56,6 +56,9 @@ class Value: def __repr__(self): return repr(self.value) + def __str__(self) -> str: + return str(self.value) + def get_sources(self): return {self.meta.source: self.meta.path()} @@ -103,7 +106,10 @@ def __iter__(self): return iter(self.data) def __eq__(self, o): - return o.data == self.data + container = type(self) + if isinstance(o, container): + return o.data == self.data + return container(o) == self def select(self, key: str): index, *rems = key.split(sep=".", maxsplit=1) diff --git a/dvc/parsing/interpolate.py b/dvc/parsing/interpolate.py index 124fb1c0ac..51160fec73 100644 --- a/dvc/parsing/interpolate.py +++ b/dvc/parsing/interpolate.py @@ -16,21 +16,27 @@ re.VERBOSE, ) +UNWRAP_DEFAULT = True -def _get_matches(template): + +def _get_matches(template: str): return list(KEYCRE.finditer(template)) +def _is_interpolated_string(val): + return bool(_get_matches(val)) if isinstance(val, str) else False + + def _unwrap(value): if isinstance(value, Value): return value.value return value -def _resolve_value(match, context: Context): +def _resolve_value(match, context: Context, unwrap=UNWRAP_DEFAULT): _, _, inner = match.groups() value = context.select(inner) - return _unwrap(value) + return _unwrap(value) if unwrap else value def _str_interpolate(template, matches, context): @@ -42,12 +48,16 @@ def _str_interpolate(template, matches, context): return buf + template[index:] -def _resolve_str(src: str, context): +def _is_exact_string(src: str, matches): + return len(matches) == 1 and src == matches[0].group(0) + + +def _resolve_str(src: str, context, unwrap=UNWRAP_DEFAULT): matches = _get_matches(src) - if len(matches) == 1 and src == matches[0].group(0): + if _is_exact_string(src, matches): # replace "${enabled}", if `enabled` is a boolean, with it's actual # value rather than it's string counterparts. - return _resolve_value(matches[0], context) + return _resolve_value(matches[0], context, unwrap=unwrap) # but not "${num} days" src = _str_interpolate(src, matches, context) diff --git a/dvc/schema.py b/dvc/schema.py index c27c54d48e..f11f37c13d 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -2,7 +2,7 @@ from dvc import dependency, output from dvc.output import CHECKSUMS_SCHEMA, BaseOutput -from dvc.parsing import FOREACH_KWD, IN_KWD, USE_KWD, VARS_KWD +from dvc.parsing import FOREACH_KWD, IN_KWD, SET_KWD, USE_KWD, VARS_KWD from dvc.stage.params import StageParams STAGES = "stages" @@ -50,6 +50,7 @@ STAGE_DEFINITION = { StageParams.PARAM_CMD: str, + Optional(SET_KWD): dict, Optional(StageParams.PARAM_WDIR): str, Optional(StageParams.PARAM_DEPS): [str], Optional(StageParams.PARAM_PARAMS): [ @@ -66,6 +67,7 @@ } FOREACH_IN = { + Optional(SET_KWD): dict, Required(FOREACH_KWD): Any(dict, list, str), Required(IN_KWD): STAGE_DEFINITION, } diff --git a/tests/func/test_stage_resolver.py b/tests/func/test_stage_resolver.py index c4b9176b74..c91d8d0028 100644 --- a/tests/func/test_stage_resolver.py +++ b/tests/func/test_stage_resolver.py @@ -1,5 +1,6 @@ import os from copy import deepcopy +from math import pi import pytest @@ -302,3 +303,100 @@ def test_foreach_loop_templatized(tmp_dir, dvc): } }, ) + + +@pytest.mark.parametrize( + "value", ["value", "To set or not to set", 3, pi, True, False, None] +) +def test_set(tmp_dir, dvc, value): + d = { + "stages": { + "build": { + "set": {"item": value}, + "cmd": "python script.py --thresh ${item}", + "always_changed": "${item}", + } + } + } + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + assert resolver.resolve() == { + "stages": { + "build": { + "cmd": f"python script.py --thresh {value}", + "always_changed": value, + } + } + } + + +@pytest.mark.parametrize( + "coll", [["foo", "bar", "baz"], {"foo": "foo", "bar": "bar"}] +) +def test_coll(tmp_dir, dvc, coll): + d = { + "stages": { + "build": { + "set": {"item": coll, "thresh": 10}, + "cmd": "python script.py --thresh ${thresh}", + "outs": "${item}", + } + } + } + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + assert resolver.resolve() == { + "stages": { + "build": {"cmd": "python script.py --thresh 10", "outs": coll} + } + } + + +def test_set_with_foreach(tmp_dir, dvc): + items = ["foo", "bar", "baz"] + d = { + "stages": { + "build": { + "set": {"items": items}, + "foreach": "${items}", + "in": {"cmd": "command --value ${item}"}, + } + } + } + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + assert resolver.resolve() == { + "stages": { + f"build-{item}": {"cmd": f"command --value {item}"} + for item in items + } + } + + +def test_set_with_foreach_and_on_stage_definition(tmp_dir, dvc): + iterable = {"models": {"us": {"thresh": 10}, "gb": {"thresh": 15}}} + dump_json(tmp_dir / "params.json", iterable) + + d = { + "use": "params.json", + "stages": { + "build": { + "set": {"data": "${models}"}, + "foreach": "${data}", + "in": { + "set": {"thresh": "${item.thresh}"}, + "cmd": "command --value ${thresh}", + }, + } + }, + } + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + assert resolver.resolve() == { + "stages": { + "build-us": { + "cmd": "command --value 10", + "params": [{"params.json": ["models.us.thresh"]}], + }, + "build-gb": { + "cmd": "command --value 15", + "params": [{"params.json": ["models.gb.thresh"]}], + }, + } + } diff --git a/tests/unit/test_stage_resolver.py b/tests/unit/test_stage_resolver.py index 808c7863ca..14c50cc2ed 100644 --- a/tests/unit/test_stage_resolver.py +++ b/tests/unit/test_stage_resolver.py @@ -1,5 +1,9 @@ +from math import pi + +import pytest + from dvc.parsing import DataResolver -from dvc.parsing.context import Context +from dvc.parsing.context import Context, Value TEMPLATED_DVC_YAML_DATA = { "stages": { @@ -38,3 +42,134 @@ def test_resolver(tmp_dir, dvc): resolver = DataResolver(dvc, tmp_dir, TEMPLATED_DVC_YAML_DATA) resolver.global_ctx = Context(CONTEXT_DATA) assert resolver.resolve() == RESOLVED_DVC_YAML_DATA + + +def test_set(): + context = Context(CONTEXT_DATA) + to_set = { + "foo": "foo", + "bar": "bar", + "pi": pi, + "true": True, + "false": False, + "none": "None", + "int": 1, + "lst2": [1, 2, 3], + "dct2": {"foo": "bar", "foobar": "foobar"}, + } + DataResolver.set_context_from(context, to_set) + + for key, value in to_set.items(): + # FIXME: using for convenience, figure out better way to do it + assert context[key] == context._convert(key, value) + + +@pytest.mark.parametrize( + "coll", + [ + ["foo", "bar", ["foo", "bar"]], + ["foo", "bar", {"foo": "foo", "bar": "bar"}], + {"foo": "foo", "bar": ["foo", "bar"]}, + {"foo": "foo", "bar": {"foo": "foo", "bar": "bar"}}, + ], +) +def test_set_nested_coll(coll): + context = Context(CONTEXT_DATA) + with pytest.raises(ValueError, match="Cannot set 'item', has nested"): + DataResolver.set_context_from(context, {"thresh": 10, "item": coll}) + + +def test_set_already_exists(): + context = Context({"item": "foo"}) + with pytest.raises( + ValueError, match="Cannot set 'item', key already exists" + ): + DataResolver.set_context_from(context, {"item": "bar"}) + + assert context["item"] == Value("foo") + + +@pytest.mark.parametrize( + "coll", [["foo", "${bar}"], {"foo": "${foo}", "bar": "bar"}], +) +def test_set_collection_interpolation(coll): + context = Context(CONTEXT_DATA) + with pytest.raises( + ValueError, match="Cannot set 'item', having interpolation inside" + ): + DataResolver.set_context_from(context, {"thresh": 10, "item": coll}) + + +def test_set_interpolated_string(): + context = Context(CONTEXT_DATA) + DataResolver.set_context_from( + context, + { + "foo": "${dict.foo}", + "bar": "${dict.bar}", + "param1": "${list.0}", + "param2": "${list.1}", + "frozen": "${freeze}", + "dict2": "${dict}", + "list2": "${list}", + }, + ) + + assert context["foo"] == Value("foo") + assert context["bar"] == Value("bar") + assert context["param1"] == Value("param1") + assert context["param2"] == Value("param2") + assert context["frozen"] == context["freeze"] == Value(True) + assert context["dict2"] == context["dict"] == CONTEXT_DATA["dict"] + assert context["list2"] == context["list"] == CONTEXT_DATA["list"] + + +def test_set_ladder(): + context = Context(CONTEXT_DATA) + DataResolver.set_context_from( + context, + { + "item": 5, + "foo": "${dict.foo}", + "bar": "${dict.bar}", + "bar2": "${bar}", + "dict2": "${dict}", + "list2": "${list}", + "dict3": "${dict2}", + "list3": "${list2}", + }, + ) + + assert context["item"] == Value(5) + assert context["foo"] == context["dict"]["foo"] == Value("foo") + assert ( + context["bar"] + == context["bar2"] + == context["dict"]["bar"] + == Value("bar") + ) + assert ( + context["dict"] + == context["dict2"] + == context["dict3"] + == CONTEXT_DATA["dict"] + ) + assert ( + context["list"] + == context["list2"] + == context["list3"] + == CONTEXT_DATA["list"] + ) + + +@pytest.mark.parametrize( + "value", + ["param ${dict.foo}", "${dict.bar}${dict.foo}", "${dict.foo}-${dict.bar}"], +) +def test_set_multiple_interpolations(value): + context = Context(CONTEXT_DATA) + with pytest.raises( + ValueError, + match=r"Cannot set 'item', joining string with interpolated string", + ): + DataResolver.set_context_from(context, {"thresh": 10, "item": value})