Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,29 @@ def loads_from(stage, s_list, erepo=None):
return ret


def _parse_params(path_params):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this parsing logic to `parse_params()` in `dvc/repo/run.py`, which is used by `repo.run()`.

path, _, params_str = path_params.rpartition(":")
params = params_str.split(",")
return path, params
def _merge_params(s_list):
d = defaultdict(list)
default_file = ParamsDependency.DEFAULT_PARAMS_FILE
for key in s_list:
if isinstance(key, str):
d[default_file].append(key)
continue
if not isinstance(key, dict):
msg = "Only list of str/dict is supported. Got: "
msg += f"'{type(key).__name__}'."
raise ValueError(msg)

for k, params in key.items():
if not isinstance(params, list):
msg = "Expected list of params for custom params file "
msg += f"'{k}', got '{type(params).__name__}'."
raise ValueError(msg)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to add assertions, so I used ValueError instead. These checks are purely internal, as errors are likely to be raised by the schema validators before execution even reaches this point.

d[k].extend(params)
return d


def loads_params(stage, s_list):
# Creates an object for each unique file that is referenced in the list
params_by_path = defaultdict(list)
for s in s_list:
path, params = _parse_params(s)
params_by_path[path].extend(params)

d_list = []
for path, params in params_by_path.items():
d_list.append(
{
BaseOutput.PARAM_PATH: path,
ParamsDependency.PARAM_PARAMS: params,
}
)

return loadd_from(stage, d_list)
d = _merge_params(s_list)
return [
ParamsDependency(stage, path, params) for path, params in d.items()
]
2 changes: 1 addition & 1 deletion dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, stage, path, params):
info=info,
)

def _dyn_load(self, values=None):
def fill_values(self, values=None):
"""Load params values dynamically."""
if not values:
return
Expand Down
48 changes: 47 additions & 1 deletion dvc/output/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections import defaultdict
from urllib.parse import urlparse

from funcy import collecting, project
from voluptuous import And, Any, Coerce, Length, Lower, Required, SetTo

from dvc.output.base import BaseOutput
Expand Down Expand Up @@ -58,7 +60,9 @@
SCHEMA[BaseOutput.PARAM_PERSIST] = bool


def _get(stage, p, info, cache, metric, plot=False, persist=False):
def _get(
stage, p, info=None, cache=True, metric=False, plot=False, persist=False
):
parsed = urlparse(p)

if parsed.scheme == "remote":
Expand Down Expand Up @@ -135,3 +139,45 @@ def loads_from(
)
for s in s_list
]


def _split_dict(d, keys):
    """Split ``d`` into the items selected by ``keys`` and everything else.

    Returns:
        A ``(selected, remaining)`` pair: ``selected`` holds the items of
        ``d`` whose key is in ``keys``, ``remaining`` holds the items whose
        key is not.
    """
    return project(d, keys), project(d, d.keys() - keys)


def _merge_data(s_list):
d = defaultdict(dict)
for key in s_list:
if isinstance(key, str):
d[key].update({})
continue
if not isinstance(key, dict):
raise ValueError(f"'{type(key).__name__}' not supported.")

for k, flags in key.items():
if not isinstance(flags, dict):
raise ValueError(
f"Expected dict for '{k}', got: '{type(flags).__name__}'"
)
d[k].update(flags)
return d


@collecting
def load_from_pipeline(stage, s_list, typ="outs"):
    """Build the outputs described by one section of a pipeline file.

    Args:
        stage: stage the outputs belong to. Its ``PARAM_OUTS``,
            ``PARAM_METRICS`` and ``PARAM_PLOTS`` attributes name the only
            sections accepted for ``typ``.
        s_list: entries of the section, in the str/dict form accepted by
            ``_merge_data``.
        typ: name of the section ``s_list`` was read from.

    Returns:
        The objects produced by ``_get``, one per distinct path; the
        ``collecting`` decorator gathers the yielded values for the caller.

    Raises:
        ValueError: if ``typ`` is not one of the allowed section names, or
            if ``_merge_data`` rejects an entry of ``s_list``.
    """
    if typ not in (stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS):
        raise ValueError(f"'{typ}' key is not allowed for pipeline files.")

    metric = typ == stage.PARAM_METRICS
    plot = typ == stage.PARAM_PLOTS

    d = _merge_data(s_list)

    for path, flags in d.items():
        plt_d = {}
        if plot:
            # Imported lazily, only when a plots section is being loaded.
            from dvc.schema import PLOT_PROPS

            # Separate the plot properties from the remaining output flags.
            plt_d, flags = _split_dict(flags, keys=PLOT_PROPS.keys())
        # Only these two flags are forwarded to `_get` as keyword arguments.
        extra = project(flags, ["cache", "persist"])
        # `plt_d or plot`: pass the plot properties when there are any,
        # otherwise just the boolean telling whether this is a plot.
        yield _get(stage, path, {}, plot=plt_d or plot, metric=metric, **extra)
20 changes: 18 additions & 2 deletions dvc/repo/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from funcy import concat, first
from funcy import concat, first, lfilter

from dvc.exceptions import InvalidArgumentError
from dvc.stage.exceptions import (
Expand All @@ -20,6 +20,19 @@ def is_valid_name(name: str):
return not INVALID_STAGENAME_CHARS & set(name)


def parse_params(path_params):
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function outputs the same structure as the one present in the pipeline file.

ret = []
for path_param in path_params:
path, _, params_str = path_param.rpartition(":")
# remove empty strings from params, on condition such as `-p "file1:"`
params = lfilter(bool, params_str.split(","))
if not path:
ret.extend(params)
else:
ret.append({path: params})
return ret


def _get_file_path(kwargs):
from dvc.dvcfile import DVC_FILE_SUFFIX, DVC_FILE

Expand Down Expand Up @@ -72,7 +85,10 @@ def run(self, fname=None, no_exec=False, single_stage=False, **kwargs):
if not is_valid_name(stage_name):
raise InvalidStageName

stage = create_stage(stage_cls, repo=self, path=path, **kwargs)
params = parse_params(kwargs.pop("params", []))
stage = create_stage(
stage_cls, repo=self, path=path, params=params, **kwargs
)
if stage is None:
return None

Expand Down
124 changes: 16 additions & 108 deletions dvc/stage/loader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import logging
import os
from collections import defaultdict
from collections.abc import Mapping
from copy import deepcopy
from itertools import chain

from funcy import first
from funcy import lcat, project

from dvc import dependency, output

from ..dependency import ParamsDependency
from . import fill_stage_dependencies
from .exceptions import StageNameUnspecified, StageNotFound

logger = logging.getLogger(__name__)
Expand All @@ -32,6 +32,7 @@ def __init__(self, dvcfile, stages_data, lockfile_data=None):

@staticmethod
def fill_from_lock(stage, lock_data):
"""Fill values for params, checksums for outs and deps from lock."""
from .params import StageParams

items = chain(
Expand All @@ -46,8 +47,8 @@ def fill_from_lock(stage, lock_data):
for key, item in items:
if isinstance(item, ParamsDependency):
# load the params with values inside lock dynamically
params = lock_data.get("params", {}).get(item.def_path, {})
item._dyn_load(params)
lock_params = lock_data.get(stage.PARAM_PARAMS, {})
item.fill_values(lock_params.get(item.def_path, {}))
continue

item.checksum = (
Expand All @@ -56,104 +57,6 @@ def fill_from_lock(stage, lock_data):
.get(item.checksum_type)
)

@classmethod
def _load_params(cls, stage, pipeline_params):
"""
File in pipeline file is expected to be in following format:
```
params:
- lr
- train.epochs
- params2.yaml: # notice the filename
- process.threshold
- process.bow
```

and, in lockfile, we keep it as following format:
```
params:
params.yaml:
lr: 0.0041
train.epochs: 100
params2.yaml:
process.threshold: 0.98
process.bow:
- 15000
- 123
```
In the list of `params` inside pipeline file, if any of the item is
dict-like, the key will be treated as separate params file and it's
values to be part of that params file, else, the item is considered
as part of the `params.yaml` which is a default file.

(From example above: `lr` is considered to be part of `params.yaml`
whereas `process.bow` to be part of `params2.yaml`.)

We only load the keys here, lockfile bears the values which are used
to compare between the actual params from the file in the workspace.
"""
res = defaultdict(list)
for key in pipeline_params:
if isinstance(key, str):
path = DEFAULT_PARAMS_FILE
res[path].append(key)
elif isinstance(key, dict):
path = first(key)
res[path].extend(key[path])

stage.deps.extend(
dependency.loadd_from(
stage,
[
{"path": key, "params": params}
for key, params in res.items()
],
)
)

@classmethod
def _load_outs(cls, stage, data, typ=None):
from dvc.output.base import BaseOutput

d = []
for key in data:
if isinstance(key, str):
entry = {BaseOutput.PARAM_PATH: key}
if typ:
entry[typ] = True
d.append(entry)
continue

assert isinstance(key, dict)
assert len(key) == 1

path = first(key)
extra = key[path]

if not typ:
d.append({BaseOutput.PARAM_PATH: path, **extra})
continue

entry = {BaseOutput.PARAM_PATH: path}

persist = extra.pop(BaseOutput.PARAM_PERSIST, False)
if persist:
entry[BaseOutput.PARAM_PERSIST] = persist

cache = extra.pop(BaseOutput.PARAM_CACHE, True)
if not cache:
entry[BaseOutput.PARAM_CACHE] = cache

entry[typ] = extra or True

d.append(entry)

stage.outs.extend(output.loadd_from(stage, d))

@classmethod
def _load_deps(cls, stage, data):
stage.deps.extend(dependency.loads_from(stage, data))

@classmethod
def load_stage(cls, dvcfile, name, stage_data, lock_data):
from . import PipelineStage, Stage, loads_from
Expand All @@ -163,13 +66,18 @@ def load_stage(cls, dvcfile, name, stage_data, lock_data):
)
stage = loads_from(PipelineStage, dvcfile.repo, path, wdir, stage_data)
stage.name = name
stage.deps, stage.outs = [], []

cls._load_outs(stage, stage_data.get("outs", []))
cls._load_outs(stage, stage_data.get("metrics", []), "metric")
cls._load_outs(stage, stage_data.get("plots", []), "plot")
cls._load_deps(stage, stage_data.get("deps", []))
cls._load_params(stage, stage_data.get("params", []))
deps = project(stage_data, [stage.PARAM_DEPS, stage.PARAM_PARAMS])
fill_stage_dependencies(stage, **deps)

outs = project(
stage_data,
[stage.PARAM_OUTS, stage.PARAM_METRICS, stage.PARAM_PLOTS],
)
stage.outs = lcat(
output.load_from_pipeline(stage, data, typ=key)
for key, data in outs.items()
)

if lock_data:
stage.cmd_changed = lock_data.get(
Expand Down
29 changes: 25 additions & 4 deletions tests/unit/dependency/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,40 @@

def test_loads_params(dvc):
stage = Stage(dvc)
deps = loads_params(stage, ["foo", "bar,baz", "a_file:qux"])
assert len(deps) == 2
deps = loads_params(
stage,
[
"foo",
"bar",
{"a_file": ["baz", "bat"]},
{"b_file": ["cat"]},
{},
{"a_file": ["foobar"]},
],
)
assert len(deps) == 3

assert isinstance(deps[0], ParamsDependency)
assert deps[0].def_path == ParamsDependency.DEFAULT_PARAMS_FILE
assert deps[0].params == ["foo", "bar", "baz"]
assert deps[0].params == ["foo", "bar"]
assert deps[0].info == {}

assert isinstance(deps[1], ParamsDependency)
assert deps[1].def_path == "a_file"
assert deps[1].params == ["qux"]
assert deps[1].params == ["baz", "bat", "foobar"]
assert deps[1].info == {}

assert isinstance(deps[2], ParamsDependency)
assert deps[2].def_path == "b_file"
assert deps[2].params == ["cat"]
assert deps[2].info == {}


@pytest.mark.parametrize("params", [[3], [{"b_file": "cat"}]])
def test_params_error(dvc, params):
    """``loads_params`` must raise ``ValueError`` for malformed entries.

    Covers an entry that is neither a str nor a dict (``3``) and a custom
    params file whose value is not a list (``"cat"``).
    """
    with pytest.raises(ValueError):
        loads_params(Stage(dvc), params)


def test_loadd_from(dvc):
stage = Stage(dvc)
Expand Down
Loading