From 553062a5b7aa9e61179b05f51e40fb9277c474a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Fri, 16 Oct 2020 19:14:08 +0545 Subject: [PATCH 1/4] Implement importing from params --- dvc/dvcfile.py | 5 +- dvc/parsing/__init__.py | 101 +++++++++++-- dvc/parsing/context.py | 239 ++++++++++++++++++++++++------ dvc/parsing/interpolate.py | 17 ++- dvc/schema.py | 7 +- tests/unit/test_stage_resolver.py | 7 +- 6 files changed, 309 insertions(+), 67 deletions(-) diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index 555cd71b67..aacfb42222 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -8,6 +8,7 @@ from dvc.exceptions import DvcException from dvc.parsing import DataResolver +from dvc.path_info import PathInfo from dvc.stage import serialize from dvc.stage.exceptions import ( StageFileBadNameError, @@ -231,7 +232,9 @@ def stages(self): if self.repo.config["feature"]["parametrization"]: with log_durations(logger.debug, "resolving values"): - resolver = DataResolver(data) + resolver = DataResolver( + self.repo, PathInfo(self.path).parent, data + ) data = resolver.resolve() lockfile_data = self._lockfile.load() diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index b24345cfa9..2a60eda5bf 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -1,27 +1,108 @@ import logging +import os +from copy import deepcopy from itertools import starmap +from typing import TYPE_CHECKING from funcy import join +from funcy.seqs import first + +from dvc.dependency.param import ParamsDependency +from dvc.path_info import PathInfo +from dvc.utils.serialize import dumps_yaml from .context import Context from .interpolate import resolve +if TYPE_CHECKING: + from dvc.repo import Repo + logger = logging.getLogger(__name__) -STAGES = "stages" +STAGES_KWD = "stages" +USE_KWD = "use" +VARS_KWD = "vars" +WDIR_KWD = "wdir" +DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE +PARAMS_KWD = "params" + +DEFAULT_SENTINEL = object() class DataResolver: - def __init__(self, d): - self.context = Context() - self.data = d + def __init__(self, repo: "Repo", yaml_wdir: PathInfo, d: dict): + to_import: PathInfo = yaml_wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE) + vars_ = d.get(VARS_KWD, {}) + if os.path.exists(to_import): + self.global_ctx_source = to_import + self.global_ctx = Context.load_from(repo.tree, str(to_import)) + else: + self.global_ctx = Context() + self.global_ctx_source = None + logger.debug( + "%s does not exist, it won't be used in parametrization", + to_import, + ) - def _resolve_entry(self, name, definition): - stage_d = resolve(definition, self.context) - logger.trace("Resolved stage data for '%s': %s", name, stage_d) - return {name: stage_d} + self.global_ctx.merge_update(vars_) + self.data: dict = d + self._yaml_wdir = yaml_wdir + self.repo = repo + + def _resolve_entry(self, name: str, definition): + context = Context.clone(self.global_ctx) + return self._resolve_stage(context, name, definition) def resolve(self): - stages = self.data.get(STAGES, {}) + stages = self.data.get(STAGES_KWD, {}) data = join(starmap(self._resolve_entry, stages.items())) - return {**self.data, STAGES: data} + logger.trace("Resolved dvc.yaml:\n%s", dumps_yaml(data)) + return {STAGES_KWD: data} + + def _resolve_stage(self, context: Context, name: str, definition) -> dict: + definition = deepcopy(definition) + wdir = self._resolve_wdir(context, definition.get(WDIR_KWD)) + if self._yaml_wdir != wdir: + logger.debug( + "Stage %s has different wdir than dvc.yaml file", name + ) + + contexts = [] + params_yaml_file = wdir / DEFAULT_PARAMS_FILE + if self.global_ctx_source != params_yaml_file: + if os.path.exists(params_yaml_file): + contexts.append( + Context.load_from(self.repo.tree, str(params_yaml_file)) + ) + else: + logger.debug( + "%s does not exist for stage %s", params_yaml_file, name + ) + + params_file = definition.get(PARAMS_KWD, []) + for item in params_file: + if item and isinstance(item, dict): + contexts.append( + Context.load_from(self.repo.tree, str(wdir / first(item))) + ) + + context.merge_update(*contexts) + + logger.trace( # pytype: disable=attribute-error + "Context during resolution of stage %s:\n%s", name, context + ) + + with context.track(): + stage_d = resolve(definition, context) + + params = stage_d.get(PARAMS_KWD, []) + context.tracked + + if params: + stage_d[PARAMS_KWD] = params + return {name: stage_d} + + def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo: + if not wdir: + return self._yaml_wdir + wdir = resolve(wdir, context) + return self._yaml_wdir / str(wdir) diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index 3fdbc2a78e..215fade541 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -1,61 +1,202 @@ -from collections.abc import Collection, Mapping, Sequence +import os +from collections import defaultdict +from collections.abc import Mapping, MutableMapping, MutableSequence +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, field, replace +from typing import Any, List, Optional, Sequence, Union -# for testing purpose -# FIXME: after implementing of reading of "params". -TEST_DATA = { - "__test__": { - "dict": {"one": 1, "two": 2, "three": "three", "four": "4"}, - "list": [1, 2, 3, 4, 3.14], - "set": {1, 2, 3}, - "tuple": (1, 2), - "bool": True, - "none": None, - "float": 3.14, - "nomnom": 1000, - } -} +from funcy import identity +from dvc.utils.serialize import LOADERS -class Context: - def __init__(self, data=None): - self.data = data or TEST_DATA - def select(self, key): - return _get_value(self.data, key) +def _merge(into, update, overwrite): + for key, val in update.items(): + if isinstance(into.get(key), Mapping) and isinstance(val, Mapping): + _merge(into[key], val, overwrite) + else: + if key in into and not overwrite: + raise ValueError( + f"Cannot overwrite as key {key} already exists in {into}" + ) + into[key] = val -def _get_item(data, idx): - if isinstance(data, Sequence): - idx = int(idx) +@dataclass +class Meta: + source: Optional[str] + dpaths: List[str] = field(default_factory=list) - if isinstance(data, (Mapping, Sequence)): - return data[idx] + @staticmethod + def update_path(meta: "Meta", path: Union[str, int]): + dpaths = meta.dpaths[:] + [str(path)] + return replace(meta, dpaths=dpaths) - raise ValueError( - f"Cannot get item '{idx}' from data of type '{type(data).__name__}'" - ) + def __str__(self): + string = self.source or ":" + string += self.path() + return string + def path(self): + return ".".join(self.dpaths) -def _get_value(data, key): - obj_and_attrs = key.strip().split(".") - value = data - for attr in obj_and_attrs: - if attr == "": - raise ValueError("Syntax error!") +@dataclass +class Value: + value: Any + meta: Meta = field(compare=False, repr=False) + + def __repr__(self): + return f"'{self}'" + + def __str__(self) -> str: + return str(self.value) + + def get_sources(self): + return {self.meta.source: self.meta.path()} + + +class Container: + meta: Meta + data: Union[list, dict] + _key_transform = staticmethod(identity) + + def __init__(self, meta) -> None: + self.meta = meta or Meta(source=None) + + def _convert(self, key, value): + meta = Meta.update_path(self.meta, key) + if value is None or isinstance(value, (int, float, str, bytes, bool)): + return Value(value, meta=meta) + elif isinstance(value, (CtxList, CtxDict, Value)): + return value + elif isinstance(value, (list, dict)): + container = CtxDict if isinstance(value, dict) else CtxList + return container(value, meta=meta) + else: + msg = "Unsupported value of type '{value}' in '{meta}'" + raise TypeError(msg) + + def __repr__(self): + return repr(self.data) + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = self._convert(key, value) + + def __delitem__(self, key): + del self.data[key] + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def __eq__(self, o): + return o.data == self.data + + def select(self, key: str): + index, *rems = key.split(sep=".", maxsplit=1) + index = index.strip() + index = self._key_transform(index) try: - value = _get_item(value, attr) - except KeyError: - msg = ( - f"Could not find '{attr}' " - "while substituting " - f"'{key}'.\n" - f"Interpolating with: {data}" - ) - raise ValueError(msg) - - if not isinstance(value, str) and isinstance(value, Collection): - raise ValueError( - f"Cannot interpolate value of type '{type(value).__name__}'" - ) - return value + d = self.data[index] + except LookupError as exc: + raise ValueError( + f"Could not find '{index}' in {self.data}" + ) from exc + return d.select(rems[0]) if rems else d + + def get_sources(self): + return {} + + +class CtxList(Container, MutableSequence): + _key_transform = staticmethod(int) + + def __init__(self, values: Sequence, meta: Meta = None): + super().__init__(meta=meta) + self.data: list = [] + self.extend(values) + + def insert(self, index: int, value): + self.data.insert(index, self._convert(index, value)) + + def get_sources(self): + return {self.meta.source: self.meta.path()} + + +class CtxDict(Container, MutableMapping): + def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs): + super().__init__(meta=meta) + + self.data: dict = {} + if mapping: + self.update(mapping) + self.update(kwargs) + + def __setitem__(self, key, value): + if not isinstance(key, str): + # limitation for the interpolation + # ignore other kinds of keys + return + return super().__setitem__(key, value) + + def merge_update(self, *args, overwrite=True): + for d in args: + _merge(self.data, d, overwrite=overwrite) + + +class Context(CtxDict): + def __init__(self, *args, **kwargs): + """ + Top level mutable dict, with some helpers to create context and track + """ + super().__init__(*args, **kwargs) + self._track = False + self._tracked_data = defaultdict(set) + + @contextmanager + def track(self): + self._track = True + yield + self._track = False + + def _track_data(self, node): + if not self._track: + return + + for source, keys in node.get_sources().items(): + if not source: + continue + params_file = self._tracked_data[source] + keys = [keys] if isinstance(keys, str) else keys + params_file.update(keys) + + @property + def tracked(self): + return [ + {file: list(keys)} for file, keys in self._tracked_data.items() + ] + + def select(self, key: str): + node = super().select(key) + self._track_data(node) + return node + + @classmethod + def load_from(cls, tree, file: str) -> "Context": + _, ext = os.path.splitext(file) + loader = LOADERS[ext] + + meta = Meta(source=file) + return cls(loader(file, tree=tree), meta=meta) + + @classmethod + def clone(cls, ctx: "Context") -> "Context": + """Clones given context.""" + return cls(deepcopy(ctx.data)) diff --git a/dvc/parsing/interpolate.py b/dvc/parsing/interpolate.py index 56e72f31d9..ca0eef0150 100644 --- a/dvc/parsing/interpolate.py +++ b/dvc/parsing/interpolate.py @@ -3,6 +3,8 @@ from funcy import rpartial +from dvc.parsing.context import Context, Value + KEYCRE = re.compile( r""" (? Date: Tue, 20 Oct 2020 11:09:37 +0545 Subject: [PATCH 2/4] Change varname --- dvc/parsing/__init__.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index 2a60eda5bf..7e3b1d073a 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -26,12 +26,10 @@ DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE PARAMS_KWD = "params" -DEFAULT_SENTINEL = object() - class DataResolver: - def __init__(self, repo: "Repo", yaml_wdir: PathInfo, d: dict): - to_import: PathInfo = yaml_wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE) + def __init__(self, repo: "Repo", wdir: PathInfo, d: dict): + to_import: PathInfo = wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE) vars_ = d.get(VARS_KWD, {}) if os.path.exists(to_import): self.global_ctx_source = to_import @@ -46,7 +44,7 @@ def __init__(self, repo: "Repo", yaml_wdir: PathInfo, d: dict): self.global_ctx.merge_update(vars_) self.data: dict = d - self._yaml_wdir = yaml_wdir + self.wdir = wdir self.repo = repo def _resolve_entry(self, name: str, definition): @@ -62,7 +60,7 @@ def resolve(self): def _resolve_stage(self, context: Context, name: str, definition) -> dict: definition = deepcopy(definition) wdir = self._resolve_wdir(context, definition.get(WDIR_KWD)) - if self._yaml_wdir != wdir: + if self.wdir != wdir: logger.debug( "Stage %s has different wdir than dvc.yaml file", name ) @@ -103,6 +101,6 @@ def _resolve_stage(self, context: Context, name: str, definition) -> dict: def _resolve_wdir(self, context: Context, wdir: str = None) -> PathInfo: if not wdir: - return self._yaml_wdir + return self.wdir wdir = resolve(wdir, context) - return self._yaml_wdir / str(wdir) + return self.wdir / str(wdir) From b576a6556a60f529e6d93d46a19dcc96d2ed9b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 20 Oct 2020 15:52:02 +0545 Subject: [PATCH 3/4] Add tests --- dvc/parsing/__init__.py | 3 +- dvc/parsing/context.py | 22 ++- dvc/parsing/interpolate.py | 7 +- tests/unit/test_context.py | 327 +++++++++++++++++++++++++++++++++++++ 4 files changed, 344 insertions(+), 15 deletions(-) create mode 100644 tests/unit/test_context.py diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index 7e3b1d073a..1e5168dde5 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -4,8 +4,7 @@ from itertools import starmap from typing import TYPE_CHECKING -from funcy import join -from funcy.seqs import first +from funcy import first, join from dvc.dependency.param import ParamsDependency from dvc.path_info import PathInfo diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index 215fade541..8211f5d885 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -25,7 +25,7 @@ def _merge(into, update, overwrite): @dataclass class Meta: - source: Optional[str] + source: Optional[str] = None dpaths: List[str] = field(default_factory=list) @staticmethod @@ -42,16 +42,19 @@ def path(self): return ".".join(self.dpaths) +def _default_meta(): + return Meta(source=None) + + @dataclass class Value: value: Any - meta: Meta = field(compare=False, repr=False) + meta: Meta = field( + compare=False, default_factory=_default_meta, repr=False + ) def __repr__(self): - return f"'{self}'" - - def __str__(self) -> str: - return str(self.value) + return repr(self.value) def get_sources(self): return {self.meta.source: self.meta.path()} @@ -75,7 +78,10 @@ def _convert(self, key, value): container = CtxDict if isinstance(value, dict) else CtxList return container(value, meta=meta) else: - msg = "Unsupported value of type '{value}' in '{meta}'" + msg = ( + "Unsupported value of type " + f"'{type(value).__name__}' in '{meta}'" + ) raise TypeError(msg) def __repr__(self): @@ -146,7 +152,7 @@ def __setitem__(self, key, value): return return super().__setitem__(key, value) - def merge_update(self, *args, overwrite=True): + def merge_update(self, *args, overwrite=False): for d in args: _merge(self.data, d, overwrite=overwrite) diff --git a/dvc/parsing/interpolate.py b/dvc/parsing/interpolate.py index ca0eef0150..124fb1c0ac 100644 --- a/dvc/parsing/interpolate.py +++ b/dvc/parsing/interpolate.py @@ -16,8 +16,6 @@ re.VERBOSE, ) -UNWRAP_DEFAULT = False - def _get_matches(template): return list(KEYCRE.finditer(template)) @@ -50,10 +48,9 @@ def _resolve_str(src: str, context): # replace "${enabled}", if `enabled` is a boolean, with it's actual # value rather than it's string counterparts. return _resolve_value(matches[0], context) - else: - # but not "${num} days" - src = _str_interpolate(src, matches, context) + # but not "${num} days" + src = _str_interpolate(src, matches, context) # regex already backtracks and avoids any `${` starting with # backslashes(`\`). We just need to replace those by `${`. return src.replace(r"\${", "${") diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py new file mode 100644 index 0000000000..54917ec99c --- /dev/null +++ b/tests/unit/test_context.py @@ -0,0 +1,327 @@ +from collections import defaultdict +from dataclasses import asdict +from math import pi + +import pytest +from funcy.seqs import first + +from dvc.parsing.context import Context, CtxDict, CtxList, Value +from dvc.tree.local import LocalTree +from dvc.utils.serialize import dump_yaml + + +def test_context(): + context = Context({"foo": "bar"}) + assert context["foo"] == Value("bar") + + context = Context(foo="bar") + assert context["foo"] == Value("bar") + + context["foobar"] = "foobar" + assert context["foobar"] == Value("foobar") + + del context["foobar"] + assert "foobar" not in context + assert "foo" in context + + with pytest.raises(KeyError): + context["foobar"] + + +def test_context_dict_ignores_keys_except_str(): + c = Context({"one": 1, 3: 3}) + assert "one" in c + assert 3 not in c + + c[3] = 3 + assert 3 not in c + + +def test_context_list(): + lst = ["foo", "bar", "baz"] + context = Context(lst=lst) + + assert context["lst"] == CtxList(lst) + assert context["lst"][0] == Value("foo") + del context["lst"][-1] + + assert "baz" not in context + + with pytest.raises(IndexError): + context["lst"][3] + + context["lst"].insert(0, "baz") + assert context["lst"] == CtxList(["baz"] + lst[:2]) + + +def test_context_setitem_getitem(): + context = Context() + lst = [1, 2, "three", True, pi, b"bytes", None] + context["list"] = lst + + assert isinstance(context["list"], CtxList) + assert context["list"] == CtxList(lst) + for i, val in enumerate(lst): + assert context["list"][i] == Value(val) + + d = { + "foo": "foo", + "bar": "bar", + "list": [ + {"foo0": "foo0", "bar0": "bar0"}, + {"foo1": "foo1", "bar1": "bar1"}, + ], + } + context["data"] = d + + assert isinstance(context["data"], CtxDict) + assert context["data"] == CtxDict(d) + assert context["data"]["foo"] == Value("foo") + assert context["data"]["bar"] == Value("bar") + + assert isinstance(context["data"]["list"], CtxList) + assert context["data"]["list"] == CtxList(d["list"]) + + for i, val in enumerate(d["list"]): + c = context["data"]["list"][i] + assert isinstance(c, CtxDict) + assert c == CtxDict(val) + assert c[f"foo{i}"] == Value(f"foo{i}") + assert c[f"bar{i}"] == Value(f"bar{i}") + + with pytest.raises(TypeError): + context["set"] = {1, 2, 3} + + +def test_loop_context(): + context = Context({"foo": "foo", "bar": "bar", "lst": [1, 2, 3]}) + + assert list(context) == ["foo", "bar", "lst"] + assert len(context) == 3 + + assert list(context["lst"]) == [Value(i) for i in [1, 2, 3]] + assert len(context["lst"]) == 3 + + assert list(context.items()) == [ + ("foo", Value("foo")), + ("bar", Value("bar")), + ("lst", CtxList([1, 2, 3])), + ] + + +def test_repr(): + data = {"foo": "foo", "bar": "bar", "lst": [1, 2, 3]} + context = Context(data) + + assert repr(context) == repr(data) + assert str(context) == str(data) + + +def test_select(): + context = Context(foo="foo", bar="bar", lst=[1, 2, 3]) + + assert context.select("foo") == Value("foo") + assert context.select("bar") == Value("bar") + assert context.select("lst") == CtxList([1, 2, 3]) + assert context.select("lst.0") == Value(1) + + with pytest.raises(ValueError): + context.select("baz") + + d = { + "lst": [ + {"foo0": "foo0", "bar0": "bar0"}, + {"foo1": "foo1", "bar1": "bar1"}, + ] + } + context = Context(d) + assert context.select("lst") == CtxList(d["lst"]) + assert context.select("lst.0") == CtxDict(d["lst"][0]) + assert context.select("lst.1") == CtxDict(d["lst"][1]) + + with pytest.raises(ValueError): + context.select("lst.2") + + for i, _ in enumerate(d["lst"]): + assert context.select(f"lst.{i}.foo{i}") == Value(f"foo{i}") + assert context.select(f"lst.{i}.bar{i}") == Value(f"bar{i}") + + +def test_merge_dict(): + d1 = {"Train": {"us": {"lr": 10}}} + d2 = {"Train": {"us": {"layers": 100}}} + + c1 = Context(d1) + c2 = Context(d2) + + c1.merge_update(c2) + assert c1.select("Train.us") == CtxDict(lr=10, layers=100) + + with pytest.raises(ValueError): + # cannot overwrite by default + c1.merge_update({"Train": {"us": {"lr": 15}}}) + + c1.merge_update({"Train": {"us": {"lr": 15}}}, overwrite=True) + assert c1.select("Train.us") == CtxDict(lr=15, layers=100) + + +def test_merge_list(): + c1 = Context(lst=[1, 2, 3]) + with pytest.raises(ValueError): + # cannot overwrite by default + c1.merge_update({"lst": [10, 11, 12]}) + + # lists are never merged + c1.merge_update({"lst": [10, 11, 12]}, overwrite=True) + assert c1.select("lst") == [10, 11, 12] + + +def test_overwrite_with_setitem(): + context = Context(foo="foo", d={"bar": "bar", "baz": "baz"}) + context["d"] = "overwrite" + assert "d" in context + assert context["d"] == Value("overwrite") + + +def test_load_from(mocker): + def _yaml_load(*args, **kwargs): + return {"x": {"y": {"z": 5}, "lst": [1, 2, 3]}, "foo": "foo"} + + mocker.patch("dvc.parsing.context.LOADERS", {".yaml": _yaml_load}) + file = "params.yaml" + c = Context.load_from(object(), file) + + assert asdict(c["x"].meta) == {"source": file, "dpaths": ["x"]} + assert asdict(c["foo"].meta) == {"source": file, "dpaths": ["foo"]} + assert asdict(c["x"]["y"].meta) == {"source": file, "dpaths": ["x", "y"]} + assert asdict(c["x"]["y"]["z"].meta) == { + "source": file, + "dpaths": ["x", "y", "z"], + } + assert asdict(c["x"]["lst"].meta) == { + "source": file, + "dpaths": ["x", "lst"], + } + assert asdict(c["x"]["lst"][0].meta) == { + "source": file, + "dpaths": ["x", "lst", "0"], + } + + +def test_clone(): + d = { + "lst": [ + {"foo0": "foo0", "bar0": "bar0"}, + {"foo1": "foo1", "bar1": "bar1"}, + ] + } + c1 = Context(d) + c2 = Context.clone(c1) + + c2["lst"][0]["foo0"] = "foo" + del c2["lst"][1]["foo1"] + + assert c1 != c2 + assert c1 == Context(d) + assert c2.select("lst.0.foo0") == Value("foo") + with pytest.raises(ValueError): + c2.select("lst.1.foo1") + + +def test_track(tmp_dir): + d = { + "lst": [ + {"foo0": "foo0", "bar0": "bar0"}, + {"foo1": "foo1", "bar1": "bar1"}, + ], + "dct": {"foo": "foo", "bar": "bar", "baz": "baz"}, + } + tree = LocalTree(None, config={}) + path = tmp_dir / "params.yaml" + dump_yaml(path, d, tree) + + context = Context.load_from(tree, str(path)) + + def key_tracked(key): + assert len(context.tracked) == 1 + return key in context.tracked[0][str(path)] + + with context.track(): + context.select("lst") + assert key_tracked("lst") + + context.select("dct") + assert not key_tracked("dct") + + context.select("dct.foo") + assert key_tracked("dct.foo") + + # Currently, it's unable to track dictionaries, as it can be merged + # from multiple sources. + context.select("lst.0") + assert not key_tracked("lst.0") + + # FIXME: either support tracking list values in ParamsDependency + # or, prevent this from being tracked. + context.select("lst.0.foo0") + assert key_tracked("lst.0.foo0") + + +def test_track_from_multiple_files(tmp_dir): + d1 = {"Train": {"us": {"lr": 10}}} + d2 = {"Train": {"us": {"layers": 100}}} + + tree = LocalTree(None, config={}) + path1 = tmp_dir / "params.yaml" + path2 = tmp_dir / "params2.yaml" + dump_yaml(path1, d1, tree) + dump_yaml(path2, d2, tree) + + context = Context.load_from(tree, str(path1)) + c = Context.load_from(tree, str(path2)) + context.merge_update(c) + + def key_tracked(path, key): + tracked = defaultdict(list) + for item in context.tracked: + source = first(item) + tracked[source].extend(item[source]) + return key in tracked[str(path)] + + with context.track(): + context.select("Train") + assert not (key_tracked(path1, "Train") or key_tracked(path2, "Train")) + + context.select("Train.us") + assert not ( + key_tracked(path1, "Train.us") or key_tracked(path2, "Train.us") + ) + + context.select("Train.us.lr") + assert key_tracked(path1, "Train.us.lr") and not key_tracked( + path2, "Train.us.lr" + ) + context.select("Train.us.layers") + assert not key_tracked(path1, "Train.us.layers") and key_tracked( + path2, "Train.us.layers" + ) + + context = Context.clone(context) + assert not context.tracked + + # let's see with an alias + context["us"] = context["Train"]["us"] + with context.track(): + context.select("us") + assert not ( + key_tracked(path1, "Train.us") or key_tracked(path2, "Train.us") + ) + + context.select("us.lr") + assert key_tracked(path1, "Train.us.lr") and not key_tracked( + path2, "Train.us.lr" + ) + context.select("Train.us.layers") + assert not key_tracked(path1, "Train.us.layers") and key_tracked( + path2, "Train.us.layers" + ) From 92bb9ba5dce5945a18b00717cc7818b707ead774 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 20 Oct 2020 16:08:40 +0545 Subject: [PATCH 4/4] fix pylint issues --- tests/unit/test_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 54917ec99c..cd8e838635 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -25,7 +25,7 @@ def test_context(): assert "foo" in context with pytest.raises(KeyError): - context["foobar"] + _ = context["foobar"] def test_context_dict_ignores_keys_except_str(): @@ -48,7 +48,7 @@ def test_context_list(): assert "baz" not in context with pytest.raises(IndexError): - context["lst"][3] + _ = context["lst"][3] context["lst"].insert(0, "baz") assert context["lst"] == CtxList(["baz"] + lst[:2])