diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index 555cd71b67..aacfb42222 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -8,6 +8,7 @@ from dvc.exceptions import DvcException from dvc.parsing import DataResolver +from dvc.path_info import PathInfo from dvc.stage import serialize from dvc.stage.exceptions import ( StageFileBadNameError, @@ -231,7 +232,9 @@ def stages(self): if self.repo.config["feature"]["parametrization"]: with log_durations(logger.debug, "resolving values"): - resolver = DataResolver(data) + resolver = DataResolver( + self.repo, PathInfo(self.path).parent, data + ) data = resolver.resolve() lockfile_data = self._lockfile.load() diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index b24345cfa9..1e5168dde5 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -1,27 +1,105 @@ import logging +import os +from copy import deepcopy from itertools import starmap +from typing import TYPE_CHECKING -from funcy import join +from funcy import first, join + +from dvc.dependency.param import ParamsDependency +from dvc.path_info import PathInfo +from dvc.utils.serialize import dumps_yaml from .context import Context from .interpolate import resolve +if TYPE_CHECKING: + from dvc.repo import Repo + logger = logging.getLogger(__name__) -STAGES = "stages" +STAGES_KWD = "stages" +USE_KWD = "use" +VARS_KWD = "vars" +WDIR_KWD = "wdir" +DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE +PARAMS_KWD = "params" class DataResolver: - def __init__(self, d): - self.context = Context() - self.data = d + def __init__(self, repo: "Repo", wdir: PathInfo, d: dict): + to_import: PathInfo = wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE) + vars_ = d.get(VARS_KWD, {}) + if os.path.exists(to_import): + self.global_ctx_source = to_import + self.global_ctx = Context.load_from(repo.tree, str(to_import)) + else: + self.global_ctx = Context() + self.global_ctx_source = None + logger.debug( + "%s does not exist, it won't be used in parametrization", + to_import, + ) - def _resolve_entry(self, name, definition): - stage_d = resolve(definition, self.context) - logger.trace("Resolved stage data for '%s': %s", name, stage_d) - return {name: stage_d} + self.global_ctx.merge_update(vars_) + self.data: dict = d + self.wdir = wdir + self.repo = repo + + def _resolve_entry(self, name: str, definition): + context = Context.clone(self.global_ctx) + return self._resolve_stage(context, name, definition) def resolve(self): - stages = self.data.get(STAGES, {}) + stages = self.data.get(STAGES_KWD, {}) data = join(starmap(self._resolve_entry, stages.items())) - return {**self.data, STAGES: data} + logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data)) + return {STAGES_KWD: data} + + def _resolve_stage(self, context: Context, name: str, definition) -> dict: + definition = deepcopy(definition) + wdir = self._resolve_wdir(context, definition.get(WDIR_KWD)) + if self.wdir != wdir: + logger.debug( + "Stage %s has different wdir than dvc.yaml file", name + ) + + contexts = [] + params_yaml_file = wdir / DEFAULT_PARAMS_FILE + if self.global_ctx_source != params_yaml_file: + if os.path.exists(params_yaml_file): + contexts.append( + Context.load_from(self.repo.tree, str(params_yaml_file)) + ) + else: + logger.debug( + "%s does not exist for stage %s", params_yaml_file, name + ) + + params_file = definition.get(PARAMS_KWD, []) + for item in params_file: + if item and isinstance(item, dict): + contexts.append( + Context.load_from(self.repo.tree, str(wdir / first(item))) + ) + + context.merge_update(*contexts) + + logger.trace( # pytype: disable=attribute-error + "Context during resolution of stage %s:\n%s", name, context + ) + + with context.track(): + stage_d = resolve(definition, context) + + params = stage_d.get(PARAMS_KWD, []) + context.tracked + + if params: + stage_d[PARAMS_KWD] = params + return {name: stage_d} + + def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo: + if not wdir: + return self.wdir + wdir = resolve(wdir, context) + return self.wdir / str(wdir) diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index 3fdbc2a78e..8211f5d885 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -1,61 +1,208 @@ -from collections.abc import Collection, Mapping, Sequence +import os +from collections import defaultdict +from collections.abc import Mapping, MutableMapping, MutableSequence +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, field, replace +from typing import Any, List, Optional, Sequence, Union -# for testing purpose -# FIXME: after implementing of reading of "params". -TEST_DATA = { - "__test__": { - "dict": {"one": 1, "two": 2, "three": "three", "four": "4"}, - "list": [1, 2, 3, 4, 3.14], - "set": {1, 2, 3}, - "tuple": (1, 2), - "bool": True, - "none": None, - "float": 3.14, - "nomnom": 1000, - } -} +from funcy import identity +from dvc.utils.serialize import LOADERS -class Context: - def __init__(self, data=None): - self.data = data or TEST_DATA - def select(self, key): - return _get_value(self.data, key) +def _merge(into, update, overwrite): + for key, val in update.items(): + if isinstance(into.get(key), Mapping) and isinstance(val, Mapping): + _merge(into[key], val, overwrite) + else: + if key in into and not overwrite: + raise ValueError( + f"Cannot overwrite as key {key} already exists in {into}" + ) + into[key] = val -def _get_item(data, idx): - if isinstance(data, Sequence): - idx = int(idx) +@dataclass +class Meta: + source: Optional[str] = None + dpaths: List[str] = field(default_factory=list) - if isinstance(data, (Mapping, Sequence)): - return data[idx] + @staticmethod + def update_path(meta: "Meta", path: Union[str, int]): + dpaths = meta.dpaths[:] + [str(path)] + return replace(meta, dpaths=dpaths) - raise ValueError( - f"Cannot get item '{idx}' from data of type '{type(data).__name__}'" + def __str__(self): + string = self.source or ":" + string += self.path() + return string + + def path(self): + return ".".join(self.dpaths) + + +def _default_meta(): + return Meta(source=None) + + +@dataclass +class Value: + value: Any + meta: Meta = field( + compare=False, default_factory=_default_meta, repr=False ) + def __repr__(self): + return repr(self.value) -def _get_value(data, key): - obj_and_attrs = key.strip().split(".") - value = data - for attr in obj_and_attrs: - if attr == "": - raise ValueError("Syntax error!") + def get_sources(self): + return {self.meta.source: self.meta.path()} - try: - value = _get_item(value, attr) - except KeyError: + +class Container: + meta: Meta + data: Union[list, dict] + _key_transform = staticmethod(identity) + + def __init__(self, meta) -> None: + self.meta = meta or Meta(source=None) + + def _convert(self, key, value): + meta = Meta.update_path(self.meta, key) + if value is None or isinstance(value, (int, float, str, bytes, bool)): + return Value(value, meta=meta) + elif isinstance(value, (CtxList, CtxDict, Value)): + return value + elif isinstance(value, (list, dict)): + container = CtxDict if isinstance(value, dict) else CtxList + return container(value, meta=meta) + else: msg = ( - f"Could not find '{attr}' " - "while substituting " - f"'{key}'.\n" - f"Interpolating with: {data}" + "Unsupported value of type " + f"'{type(value).__name__}' in '{meta}'" ) - raise ValueError(msg) + raise TypeError(msg) + + def __repr__(self): + return repr(self.data) + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = self._convert(key, value) + + def __delitem__(self, key): + del self.data[key] + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def __eq__(self, o): + return o.data == self.data + + def select(self, key: str): + index, *rems = key.split(sep=".", maxsplit=1) + index = index.strip() + index = self._key_transform(index) + try: + d = self.data[index] + except LookupError as exc: + raise ValueError( + f"Could not find '{index}' in {self.data}" + ) from exc + return d.select(rems[0]) if rems else d + + def get_sources(self): + return {} + + +class CtxList(Container, MutableSequence): + _key_transform = staticmethod(int) + + def __init__(self, values: Sequence, meta: Meta = None): + super().__init__(meta=meta) + self.data: list = [] + self.extend(values) + + def insert(self, index: int, value): + self.data.insert(index, self._convert(index, value)) + + def get_sources(self): + return {self.meta.source: self.meta.path()} + + +class CtxDict(Container, MutableMapping): + def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs): + super().__init__(meta=meta) + + self.data: dict = {} + if mapping: + self.update(mapping) + self.update(kwargs) + + def __setitem__(self, key, value): + if not isinstance(key, str): + # limitation for the interpolation + # ignore other kinds of keys + return + return super().__setitem__(key, value) + + def merge_update(self, *args, overwrite=False): + for d in args: + _merge(self.data, d, overwrite=overwrite) + + +class Context(CtxDict): + def __init__(self, *args, **kwargs): + """ + Top level mutable dict, with some helpers to create context and track + """ + super().__init__(*args, **kwargs) + self._track = False + self._tracked_data = defaultdict(set) + + @contextmanager + def track(self): + self._track = True + yield + self._track = False + + def _track_data(self, node): + if not self._track: + return + + for source, keys in node.get_sources().items(): + if not source: + continue + params_file = self._tracked_data[source] + keys = [keys] if isinstance(keys, str) else keys + params_file.update(keys) + + @property + def tracked(self): + return [ + {file: list(keys)} for file, keys in self._tracked_data.items() + ] + + def select(self, key: str): + node = super().select(key) + self._track_data(node) + return node + + @classmethod + def load_from(cls, tree, file: str) -> "Context": + _, ext = os.path.splitext(file) + loader = LOADERS[ext] + + meta = Meta(source=file) + return cls(loader(file, tree=tree), meta=meta) - if not isinstance(value, str) and isinstance(value, Collection): - raise ValueError( - f"Cannot interpolate value of type '{type(value).__name__}'" - ) - return value + @classmethod + def clone(cls, ctx: "Context") -> "Context": + """Clones given context.""" + return cls(deepcopy(ctx.data)) diff --git a/dvc/parsing/interpolate.py b/dvc/parsing/interpolate.py index 56e72f31d9..124fb1c0ac 100644 --- a/dvc/parsing/interpolate.py +++ b/dvc/parsing/interpolate.py @@ -3,6 +3,8 @@ from funcy import rpartial +from dvc.parsing.context import Context, Value + KEYCRE = re.compile( r""" (?