diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index a0a65d59a2..0136536f6c 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -529,15 +529,24 @@ def template(self) -> "DictStrAny": return self._template @cached_property - def resolved_iterable(self) -> dict[str, list]: + def resolved_iterable(self) -> dict[str, Union[list, Mapping]]: return self._resolve_matrix_data() - def _resolve_matrix_data(self) -> dict[str, list]: + def _resolve_matrix_data(self) -> dict[str, Union[list, Mapping]]: try: iterable = self.context.resolve(self.matrix_data, unwrap=False) except (ContextError, ParseError) as exc: format_and_raise(exc, f"'{self.where}.{self.name}.matrix'", self.relpath) + for key, value in iterable.items(): + if not is_map_or_seq(value): + node = value.value if isinstance(value, Node) else value + typ = type(node).__name__ + raise ResolveError( + f"failed to resolve '{self.where}.{self.name}.matrix.{key}'" + f" in '{self.relpath}': expected list/dictionary, got {typ}" + ) + # Matrix entries will have `key` and `item` added to the context. # Warn users if these are already in the context from the global vars. self._warn_if_overwriting([self.pair.key, self.pair.value]) @@ -564,13 +573,27 @@ def normalized_iterable(self) -> dict[str, "DictStrAny"]: assert isinstance(iterable, Mapping) ret: dict[str, DictStrAny] = {} - matrix = {key: enumerate(v) for key, v in iterable.items()} + matrix = {} + for key, value in iterable.items(): + if isinstance(value, Mapping): + # For mappings, use (key, value) pairs. + # Key is used for naming, value is used for context. + matrix[key] = [(to_str(k), v) for k, v in value.items()] + else: + # For sequences, use (name, value) pairs. + # Name is index-based for complex items, or value-based for simple items + items = [] + for i, v in enumerate(value): + name = f"{key}{i}" if is_map_or_seq(v) else to_str(v) + items.append((name, v)) + matrix[key] = items + for combination in product(*matrix.values()): d: DictStrAny = {} fragments: list[str] = [] - for k, (i, v) in zip(matrix.keys(), combination): + for k, (name, v) in zip(matrix.keys(), combination): d[k] = v - fragments.append(f"{k}{i}" if is_map_or_seq(v) else to_str(v)) + fragments.append(name) key = "-".join(fragments) ret[key] = d diff --git a/dvc/schema.py b/dvc/schema.py index 85001016fc..9ec9bd9ded 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -80,7 +80,7 @@ VARS_SCHEMA = [str, dict] STAGE_DEFINITION = { - MATRIX_KWD: {str: vol.Any(str, list)}, + MATRIX_KWD: {str: vol.Any(str, list, dict)}, vol.Required(StageParams.PARAM_CMD): vol.Any(str, list), vol.Optional(StageParams.PARAM_WDIR): str, vol.Optional(StageParams.PARAM_DEPS): [str], diff --git a/tests/func/parsing/test_matrix.py b/tests/func/parsing/test_matrix.py index 26ffb527b1..a319b916fc 100644 --- a/tests/func/parsing/test_matrix.py +++ b/tests/func/parsing/test_matrix.py @@ -1,6 +1,7 @@ import pytest -from dvc.parsing import DataResolver, MatrixDefinition +from dvc.parsing import DataResolver, MatrixDefinition, ResolveError +from dvc.schema import COMPILED_MULTI_STAGE_SCHEMA MATRIX_DATA = { "os": ["win", "linux"], @@ -91,3 +92,90 @@ def test_matrix_key_present(tmp_dir, dvc, matrix): "build@linux-3.8-dict1-list0": {"cmd": "echo linux-3.8-dict1-list0"}, "build@linux-3.8-dict1-list1": {"cmd": "echo linux-3.8-dict1-list1"}, } + + +def test_matrix_schema_allows_mapping(): + data = { + "stages": { + "build": { + "matrix": {"models": {"goo": {"val": 1}, "baz": {"val": 2}}}, + "cmd": "echo ${item.models.val}", + } + } + } + COMPILED_MULTI_STAGE_SCHEMA(data) + + +MAPPING_MATRIX_DATA = {"goo": {"val": 1}, "baz": {"val": 2}} + + +@pytest.mark.parametrize( + "matrix", + [ + {"models": MAPPING_MATRIX_DATA}, + {"models": "${map_param}"}, + ], +) +def test_matrix_with_mapping(tmp_dir, dvc, matrix): + (tmp_dir / "params.yaml").dump({"map_param": MAPPING_MATRIX_DATA}) + resolver = DataResolver(dvc, tmp_dir.fs_path, {}) + data = {"matrix": matrix, "cmd": "echo ${item.models.val}"} + definition = MatrixDefinition(resolver, resolver.context, "build", data) + + assert definition.resolve_all() == { + "build@goo": {"cmd": "echo 1"}, + "build@baz": {"cmd": "echo 2"}, + } + + +@pytest.mark.parametrize( + "matrix", + [ + {"models": MAPPING_MATRIX_DATA, "ver": [1, 2]}, + {"models": "${map_param}", "ver": "${ver}"}, + ], +) +def test_matrix_mixed_mapping_and_list(tmp_dir, dvc, matrix): + (tmp_dir / "params.yaml").dump({"map_param": MAPPING_MATRIX_DATA, "ver": [1, 2]}) + resolver = DataResolver(dvc, tmp_dir.fs_path, {}) + data = {"matrix": matrix, "cmd": "echo ${item.models.val} ${item.ver}"} + definition = MatrixDefinition(resolver, resolver.context, "build", data) + + assert definition.resolve_all() == { + "build@goo-1": {"cmd": "echo 1 1"}, + "build@goo-2": {"cmd": "echo 1 2"}, + "build@baz-1": {"cmd": "echo 2 1"}, + "build@baz-2": {"cmd": "echo 2 2"}, + } + + +@pytest.mark.parametrize( + "matrix", + [ + {"models": MAPPING_MATRIX_DATA}, + {"models": "${map_param}"}, + ], +) +def test_matrix_mapping_key_present(tmp_dir, dvc, matrix): + (tmp_dir / "params.yaml").dump({"map_param": MAPPING_MATRIX_DATA}) + resolver = DataResolver(dvc, tmp_dir.fs_path, {}) + data = {"matrix": matrix, "cmd": "echo ${key}"} + definition = MatrixDefinition(resolver, resolver.context, "build", data) + + assert definition.resolve_all() == { + "build@goo": {"cmd": "echo goo"}, + "build@baz": {"cmd": "echo baz"}, + } + + +@pytest.mark.parametrize("matrix_value", ["${foo}", "${dct.model1}", "foobar"]) +def test_matrix_expects_list_or_dict(tmp_dir, dvc, matrix_value): + (tmp_dir / "params.yaml").dump({"foo": "bar", "dct": {"model1": "a-out"}}) + resolver = DataResolver(dvc, tmp_dir.fs_path, {}) + data = {"matrix": {"dim": matrix_value}, "cmd": "echo ${item.dim}"} + definition = MatrixDefinition(resolver, resolver.context, "build", data) + + with pytest.raises(ResolveError) as exc_info: + definition.resolve_all() + assert "expected list/dictionary, got str" in str(exc_info.value) + assert "stages.build.matrix.dim" in str(exc_info.value)