From 9204072ec74a42287261934561ae452416ebcc40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Redzy=C5=84ski?= Date: Thu, 22 Jul 2021 12:11:15 +0200 Subject: [PATCH 1/2] plots: cleanup data extraction --- dvc/repo/plots/data.py | 54 +++++------------------------- setup.py | 1 - tests/func/plots/test_show.py | 12 ++----- tests/unit/repo/plots/test_data.py | 19 +---------- 4 files changed, 13 insertions(+), 73 deletions(-) diff --git a/dvc/repo/plots/data.py b/dvc/repo/plots/data.py index 5268d81152..9ce7518ab2 100644 --- a/dvc/repo/plots/data.py +++ b/dvc/repo/plots/data.py @@ -33,14 +33,10 @@ def __init__(self, path, revision): def plot_data(filename, revision, content): _, extension = os.path.splitext(filename.lower()) - if extension == ".json": - return JSONPlotData(filename, revision, content) - if extension == ".csv": - return CSVPlotData(filename, revision, content) - if extension == ".tsv": - return CSVPlotData(filename, revision, content, delimiter="\t") - if extension == ".yaml": - return YAMLPlotData(filename, revision, content) + if extension in (".json", ".yaml"): + return DictData(filename, revision, content) + if extension in (".csv", ".tsv"): + return ListData(filename, revision, content) raise PlotMetricTypeError(filename) @@ -68,34 +64,6 @@ def _filter_fields(data_points, filename, revision, fields=None, **kwargs): return new_data -def _apply_path(data, path=None, **kwargs): - if not path or not isinstance(data, dict): - return data - - import jsonpath_ng - - found = jsonpath_ng.parse(path).find(data) - first_datum = first(found) - if ( - len(found) == 1 - and isinstance(first_datum.value, list) - and isinstance(first(first_datum.value), dict) - ): - data_points = first_datum.value - elif len(first_datum.path.fields) == 1: - field_name = first(first_datum.path.fields) - data_points = [{field_name: datum.value} for datum in found] - else: - raise PlotDataStructureError() - - if not isinstance(data_points, list) or not ( - isinstance(first(data_points), dict) - ): - raise PlotDataStructureError() - - return data_points - - def _lists(dictionary): for _, value in dictionary.items(): if isinstance(value, dict): @@ -158,17 +126,13 @@ def to_datapoints(self, **kwargs): return data -class JSONPlotData(PlotData): +class DictData(PlotData): + # For files usually parsed as dicts: eg JSON, Yaml def _processors(self): parent_processors = super()._processors() - return [_apply_path, _find_data] + parent_processors + return [_find_data] + parent_processors -class CSVPlotData(PlotData): +class ListData(PlotData): + # For files parsed as list: CSV, TSV pass - - -class YAMLPlotData(PlotData): - def _processors(self): - parent_processors = super()._processors() - return [_find_data] + parent_processors diff --git a/setup.py b/setup.py index 089362a82e..9c4f5efdfd 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,6 @@ def run(self): "nanotime>=0.5.2", "pyasn1>=0.4.1", "voluptuous>=0.11.7", - "jsonpath-ng>=1.5.1", "requests>=2.22.0", "grandalf==0.6", "distro>=1.3.0", diff --git a/tests/func/plots/test_show.py b/tests/func/plots/test_show.py index 8586506a42..66018fc5e4 100644 --- a/tests/func/plots/test_show.py +++ b/tests/func/plots/test_show.py @@ -12,12 +12,7 @@ from dvc.main import main from dvc.path_info import PathInfo from dvc.repo import Repo -from dvc.repo.plots.data import ( - JSONPlotData, - PlotData, - PlotMetricTypeError, - YAMLPlotData, -) +from dvc.repo.plots.data import DictData, PlotData, PlotMetricTypeError from dvc.repo.plots.template import ( BadTemplateError, NoFieldInDataError, @@ -560,12 +555,11 @@ def test_raise_on_wrong_field(tmp_dir, scm, dvc, run_copy_metrics): dvc.plots.show("metric.json", props={"y": "no_val"}) -@pytest.mark.parametrize("data_class", [JSONPlotData, YAMLPlotData]) -def test_find_data_in_dict(tmp_dir, data_class): +def test_find_data_in_dict(tmp_dir): metric = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}] dmetric = {"train": metric} - plot_data = data_class("-", "revision", dmetric) + plot_data = DictData("-", "revision", dmetric) expected = metric for d in expected: diff --git a/tests/unit/repo/plots/test_data.py b/tests/unit/repo/plots/test_data.py index 687d629796..1771303d63 100644 --- a/tests/unit/repo/plots/test_data.py +++ b/tests/unit/repo/plots/test_data.py @@ -2,24 +2,7 @@ import pytest -from dvc.repo.plots.data import _apply_path, _find_data, _lists - - -@pytest.mark.parametrize( - "path,expected_result", - [ - ("$.some.path[*].a", [{"a": 1}, {"a": 4}]), - ("$.some.path", [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]), - ], -) -def test_parse_json(path, expected_result): - value = { - "some": {"path": [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]} - } - - result = _apply_path(value, path=path) - - assert result == expected_result +from dvc.repo.plots.data import _find_data, _lists @pytest.mark.parametrize( From dd1492de6d786c28d88e4f862e3f5207d5771c0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Redzy=C5=84ski?= Date: Thu, 22 Jul 2021 12:46:34 +0200 Subject: [PATCH 2/2] fixup --- tests/func/plots/test_show.py | 15 +-------------- tests/unit/repo/plots/test_data.py | 25 +++++++++++++++++++------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/func/plots/test_show.py b/tests/func/plots/test_show.py index 66018fc5e4..bdc32e4e8f 100644 --- a/tests/func/plots/test_show.py +++ b/tests/func/plots/test_show.py @@ -12,7 +12,7 @@ from dvc.main import main from dvc.path_info import PathInfo from dvc.repo import Repo -from dvc.repo.plots.data import DictData, PlotData, PlotMetricTypeError +from dvc.repo.plots.data import PlotData, PlotMetricTypeError from dvc.repo.plots.template import ( BadTemplateError, NoFieldInDataError, @@ -555,19 +555,6 @@ def test_raise_on_wrong_field(tmp_dir, scm, dvc, run_copy_metrics): dvc.plots.show("metric.json", props={"y": "no_val"}) -def test_find_data_in_dict(tmp_dir): - metric = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}] - dmetric = {"train": metric} - - plot_data = DictData("-", "revision", dmetric) - - expected = metric - for d in expected: - d["rev"] = "revision" - - assert list(map(dict, plot_data.to_datapoints())) == expected - - def test_multiple_plots(tmp_dir, scm, dvc, run_copy_metrics): metric1 = [ OrderedDict([("first_val", 100), ("second_val", 100), ("val", 2)]), diff --git a/tests/unit/repo/plots/test_data.py b/tests/unit/repo/plots/test_data.py index 1771303d63..93f2975234 100644 --- a/tests/unit/repo/plots/test_data.py +++ b/tests/unit/repo/plots/test_data.py @@ -1,8 +1,9 @@ from collections import OrderedDict +from typing import Dict, List import pytest -from dvc.repo.plots.data import _find_data, _lists +from dvc.repo.plots.data import DictData, _lists @pytest.mark.parametrize( @@ -22,10 +23,22 @@ def test_finding_lists(dictionary, expected_result): assert list(result) == expected_result -@pytest.mark.parametrize("fields", [{"x"}, set()]) -def test_finding_data(fields): - data = {"a": {"b": [{"x": 2, "y": 3}, {"x": 1, "y": 5}]}} +def test_find_data_in_dict(tmp_dir): + m1 = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}] + m2 = [{"x": 1}, {"x": 2}] + dmetric = OrderedDict([("t1", m1), ("t2", m2)]) - result = _find_data(data, fields=fields) + plot_data = DictData("-", "revision", dmetric) - assert result == [{"x": 2, "y": 3}, {"x": 1, "y": 5}] + def points_with(datapoints: List, additional_info: Dict): + for datapoint in datapoints: + datapoint.update(additional_info) + + return datapoints + + assert list(map(dict, plot_data.to_datapoints())) == points_with( + m1, {"rev": "revision"} + ) + assert list( + map(dict, plot_data.to_datapoints(fields={"x"})) + ) == points_with(m2, {"rev": "revision"})