Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from dvc.exceptions import DvcException
from dvc.parsing import DataResolver
from dvc.path_info import PathInfo
from dvc.stage import serialize
from dvc.stage.exceptions import (
StageFileBadNameError,
Expand Down Expand Up @@ -231,7 +232,9 @@ def stages(self):

if self.repo.config["feature"]["parametrization"]:
with log_durations(logger.debug, "resolving values"):
resolver = DataResolver(data)
resolver = DataResolver(
self.repo, PathInfo(self.path).parent, data
)
data = resolver.resolve()

lockfile_data = self._lockfile.load()
Expand Down
100 changes: 89 additions & 11 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,105 @@
import logging
import os
from copy import deepcopy
from itertools import starmap
from typing import TYPE_CHECKING

from funcy import join
from funcy import first, join

from dvc.dependency.param import ParamsDependency
from dvc.path_info import PathInfo
from dvc.utils.serialize import dumps_yaml

from .context import Context
from .interpolate import resolve

if TYPE_CHECKING:
from dvc.repo import Repo

logger = logging.getLogger(__name__)

STAGES = "stages"
STAGES_KWD = "stages"
USE_KWD = "use"
VARS_KWD = "vars"
WDIR_KWD = "wdir"
DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
PARAMS_KWD = "params"


class DataResolver:
def __init__(self, d):
self.context = Context()
self.data = d
def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):
to_import: PathInfo = wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE)
vars_ = d.get(VARS_KWD, {})
if os.path.exists(to_import):
self.global_ctx_source = to_import
self.global_ctx = Context.load_from(repo.tree, str(to_import))
else:
self.global_ctx = Context()
self.global_ctx_source = None
logger.debug(
"%s does not exist, it won't be used in parametrization",
to_import,
)

def _resolve_entry(self, name, definition):
stage_d = resolve(definition, self.context)
logger.trace("Resolved stage data for '%s': %s", name, stage_d)
return {name: stage_d}
self.global_ctx.merge_update(vars_)
self.data: dict = d
self.wdir = wdir
self.repo = repo

def _resolve_entry(self, name: str, definition):
    """Resolve one stage entry against a private copy of the global context.

    The global context is cloned first, so whatever ``_resolve_stage``
    merges into the context for this stage does not affect other stages.
    Returns whatever ``_resolve_stage`` returns.
    """
    context = Context.clone(self.global_ctx)
    return self._resolve_stage(context, name, definition)

def resolve(self):
stages = self.data.get(STAGES, {})
stages = self.data.get(STAGES_KWD, {})
data = join(starmap(self._resolve_entry, stages.items()))
return {**self.data, STAGES: data}
logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data))
return {STAGES_KWD: data}

def _resolve_stage(self, context: Context, name: str, definition) -> dict:
    """Interpolate a single stage definition and return ``{name: stage}``.

    ``context`` is extended in place with the stage-local params file and
    with every params file named as a dict entry under the stage's
    ``params`` key; the definition itself is deep-copied and never mutated.
    Keys read from the context during interpolation are appended to the
    stage's ``params`` list.
    """
    definition = deepcopy(definition)
    stage_wdir = self._resolve_wdir(context, definition.get(WDIR_KWD))
    if self.wdir != stage_wdir:
        logger.debug(
            "Stage %s has different wdir than dvc.yaml file", name
        )

    extra_contexts = []

    # The default params file of the stage's wdir is loaded only when it is
    # not the very file the global context was already built from.
    default_params = stage_wdir / DEFAULT_PARAMS_FILE
    if self.global_ctx_source != default_params:
        if os.path.exists(default_params):
            extra_contexts.append(
                Context.load_from(self.repo.tree, str(default_params))
            )
        else:
            logger.debug(
                "%s does not exist for stage %s", default_params, name
            )

    # Dict entries under ``params`` are keyed by a params file path
    # (relative to the stage wdir); other entries are ignored here.
    extra_contexts.extend(
        Context.load_from(self.repo.tree, str(stage_wdir / first(entry)))
        for entry in definition.get(PARAMS_KWD, [])
        if entry and isinstance(entry, dict)
    )

    context.merge_update(*extra_contexts)

    logger.trace(  # pytype: disable=attribute-error
        "Context during resolution of stage %s:\n%s", name, context
    )

    # Record which context keys the interpolation actually reads.
    with context.track():
        resolved = resolve(definition, context)

    all_params = resolved.get(PARAMS_KWD, []) + context.tracked
    if all_params:
        resolved[PARAMS_KWD] = all_params

    return {name: resolved}

def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo:
    """Return the working directory for a stage.

    A falsy ``wdir`` means the stage runs in the resolver's own directory.
    Otherwise ``wdir`` is interpolated against ``context`` and joined onto
    the resolver's directory.
    """
    if wdir:
        interpolated = resolve(wdir, context)
        return self.wdir / str(interpolated)
    return self.wdir
239 changes: 193 additions & 46 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,208 @@
from collections.abc import Collection, Mapping, Sequence
import os
from collections import defaultdict
from collections.abc import Mapping, MutableMapping, MutableSequence
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional, Sequence, Union

# for testing purpose
# FIXME: after implementing of reading of "params".
TEST_DATA = {
"__test__": {
"dict": {"one": 1, "two": 2, "three": "three", "four": "4"},
"list": [1, 2, 3, 4, 3.14],
"set": {1, 2, 3},
"tuple": (1, 2),
"bool": True,
"none": None,
"float": 3.14,
"nomnom": 1000,
}
}
from funcy import identity

from dvc.utils.serialize import LOADERS

class Context:
def __init__(self, data=None):
self.data = data or TEST_DATA

def select(self, key):
return _get_value(self.data, key)
def _merge(into, update, overwrite):
    """Recursively merge ``update`` into ``into``, modifying ``into`` in place.

    When both sides hold a mapping under the same key, the merge descends
    into it. Any other key is assigned directly; if it already exists in
    ``into`` and ``overwrite`` is false, ``ValueError`` is raised. Keys
    processed before the failing one stay merged.
    """
    for key, incoming in update.items():
        both_mappings = isinstance(into.get(key), Mapping) and isinstance(
            incoming, Mapping
        )
        if both_mappings:
            _merge(into[key], incoming, overwrite)
            continue

        if not overwrite and key in into:
            raise ValueError(
                f"Cannot overwrite as key {key} already exists in {into}"
            )
        into[key] = incoming


def _get_item(data, idx):
if isinstance(data, Sequence):
idx = int(idx)
@dataclass
class Meta:
source: Optional[str] = None
dpaths: List[str] = field(default_factory=list)

if isinstance(data, (Mapping, Sequence)):
return data[idx]
@staticmethod
def update_path(meta: "Meta", path: Union[str, int]):
dpaths = meta.dpaths[:] + [str(path)]
return replace(meta, dpaths=dpaths)

raise ValueError(
f"Cannot get item '{idx}' from data of type '{type(data).__name__}'"
def __str__(self):
string = self.source or "<local>:"
string += self.path()
return string

def path(self):
return ".".join(self.dpaths)


def _default_meta():
    """Build the metadata used for values that have no known source."""
    # ``source`` already defaults to None on Meta, so no argument is needed.
    return Meta()


@dataclass
class Value:
value: Any
meta: Meta = field(
compare=False, default_factory=_default_meta, repr=False
)

def __repr__(self):
return repr(self.value)

def _get_value(data, key):
obj_and_attrs = key.strip().split(".")
value = data
for attr in obj_and_attrs:
if attr == "":
raise ValueError("Syntax error!")
def get_sources(self):
return {self.meta.source: self.meta.path()}

try:
value = _get_item(value, attr)
except KeyError:

class Container:
meta: Meta
data: Union[list, dict]
_key_transform = staticmethod(identity)

def __init__(self, meta) -> None:
self.meta = meta or Meta(source=None)

def _convert(self, key, value):
meta = Meta.update_path(self.meta, key)
if value is None or isinstance(value, (int, float, str, bytes, bool)):
return Value(value, meta=meta)
elif isinstance(value, (CtxList, CtxDict, Value)):
return value
elif isinstance(value, (list, dict)):
container = CtxDict if isinstance(value, dict) else CtxList
return container(value, meta=meta)
else:
msg = (
f"Could not find '{attr}' "
"while substituting "
f"'{key}'.\n"
f"Interpolating with: {data}"
"Unsupported value of type "
f"'{type(value).__name__}' in '{meta}'"
)
raise ValueError(msg)
raise TypeError(msg)

def __repr__(self):
    # Show the wrapped list/dict rather than the wrapper object.
    return repr(self.data)

def __getitem__(self, key):
    # Direct lookup in the backing store; the key is used as given.
    return self.data[key]

def __setitem__(self, key, value):
    # Every stored value goes through _convert first.
    self.data[key] = self._convert(key, value)

def __delitem__(self, key):
    # Remove the entry from the backing store.
    del self.data[key]

def __len__(self):
    # Number of entries in the backing store.
    return len(self.data)

def __iter__(self):
    # Iterates the backing store: elements for a list, keys for a dict.
    return iter(self.data)

def __eq__(self, o):
    """Compare by backing data.

    Two containers are equal when their ``data`` attributes are equal.
    For an object that has no ``data`` attribute, ``NotImplemented`` is
    returned so Python can fall back to the other operand's comparison
    (ending in ``False``) instead of raising ``AttributeError``.
    """
    if not hasattr(o, "data"):
        return NotImplemented
    return o.data == self.data

def select(self, key: str):
    """Return the node addressed by a dot-separated ``key``.

    The first segment of ``key`` is stripped of surrounding whitespace,
    passed through ``_key_transform`` and looked up in ``self.data``; the
    rest of the key, if any, is resolved by the node that was found.

    Raises:
        ValueError: if the first segment is not present in ``self.data``,
            or if the key continues past a node that cannot be descended
            into (one that has no ``select`` method).
    """
    index, *rems = key.split(sep=".", maxsplit=1)
    index = self._key_transform(index.strip())
    try:
        d = self.data[index]
    except LookupError as exc:
        raise ValueError(
            f"Could not find '{index}' in {self.data}"
        ) from exc

    if not rems:
        return d

    rem = rems[0]
    # Leaf nodes have no ``select``; report that as the same ValueError
    # used for other lookup failures rather than an AttributeError.
    select_child = getattr(d, "select", None)
    if select_child is None:
        raise ValueError(
            f"'{index}' is a primitive value, cannot get '{rem}'"
        )
    return select_child(rem)

def get_sources(self):
return {}


class CtxList(Container, MutableSequence):
    """List-backed container whose elements are wrapped on insertion."""

    # Path segments arrive as strings; list positions must be integers.
    _key_transform = staticmethod(int)

    def __init__(self, values: Sequence, meta: Meta = None):
        """Create the list and add ``values`` through the wrapping ``insert``."""
        super().__init__(meta=meta)
        self.data: list = []
        self.extend(values)

    def insert(self, index: int, value):
        """Wrap ``value`` (keyed by ``index``) and insert it at ``index``."""
        wrapped = self._convert(index, value)
        self.data.insert(index, wrapped)

    def get_sources(self):
        """Return ``{source: path}`` describing where this list came from."""
        meta = self.meta
        return {meta.source: meta.path()}


class CtxDict(Container, MutableMapping):
    """Dict-backed container that only accepts string keys."""

    def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs):
        """Populate from ``mapping`` (when truthy) and then from ``kwargs``."""
        super().__init__(meta=meta)

        self.data: dict = {}
        if mapping:
            self.update(mapping)
        self.update(kwargs)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``; non-string keys are silently ignored."""
        # limitation for the interpolation: other kinds of keys are dropped
        if isinstance(key, str):
            return super().__setitem__(key, value)
        return None

    def merge_update(self, *args, overwrite=False):
        """Merge each mapping in ``args`` into this dict, in order.

        The merge writes into the raw backing dict, so new top-level
        entries are stored as given rather than going through
        ``__setitem__``.
        """
        for other in args:
            _merge(self.data, other, overwrite=overwrite)


class Context(CtxDict):
def __init__(self, *args, **kwargs):
    """
    Top level mutable dict, with some helpers to create context and track

    All arguments are forwarded unchanged to the parent initializer.
    """
    super().__init__(*args, **kwargs)
    # Tracking is off until ``track()`` is entered.
    self._track = False
    # Maps a source to the set of keys recorded for it while tracking.
    self._tracked_data = defaultdict(set)

@contextmanager
def track(self):
    """Enable key tracking for the duration of the ``with`` block.

    While active, ``select`` records the sources of the nodes it returns.
    Tracking is switched off again when the block exits, including when
    the block raises, so a failed resolution cannot leave it enabled.
    """
    self._track = True
    try:
        yield
    finally:
        self._track = False

def _track_data(self, node):
    """Record the sources of ``node`` when tracking is enabled.

    Each ``source -> keys`` pair reported by ``node.get_sources()`` is
    added to the tracked set for that source. Pairs with a falsy source
    are skipped. A single string key is treated as a one-element
    collection.
    """
    if self._track:
        for source, keys in node.get_sources().items():
            if source:
                key_group = [keys] if isinstance(keys, str) else keys
                self._tracked_data[source].update(key_group)

@property
def tracked(self):
    """Tracked keys as a list of single-entry ``{file: [keys]}`` dicts."""
    entries = []
    for file, keys in self._tracked_data.items():
        entries.append({file: list(keys)})
    return entries

def select(self, key: str):
    """Look up ``key`` via the parent ``select`` and record the access.

    The found node is handed to ``_track_data`` (which does nothing unless
    tracking is enabled) and then returned unchanged.
    """
    node = super().select(key)
    self._track_data(node)
    return node

@classmethod
def load_from(cls, tree, file: str) -> "Context":
    """Build a context from ``file``, read through ``tree``.

    The loader is chosen from ``LOADERS`` by the file's extension, and the
    resulting context carries ``file`` as its source.
    """
    extension = os.path.splitext(file)[1]
    load = LOADERS[extension]
    return cls(load(file, tree=tree), meta=Meta(source=file))

if not isinstance(value, str) and isinstance(value, Collection):
raise ValueError(
f"Cannot interpolate value of type '{type(value).__name__}'"
)
return value
@classmethod
def clone(cls, ctx: "Context") -> "Context":
    """Return a new context built from a deep copy of ``ctx``'s data.

    Only ``ctx.data`` is copied; no ``meta`` is passed to the new
    instance.
    """
    copied = deepcopy(ctx.data)
    return cls(copied)
Loading