diff --git a/CHANGELOG.md b/CHANGELOG.md index 86f7f067..af32256c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +* Created tests for the vector module (@nkorinek, #209) * Created functions to test point geometries in VectorTester (@nkorinek, #176) * made `assert_string_contains()` accept correct strings with spaces in them (@nkorinek, #182) * added contributors file and updated README to remove that information (@nkorinek, #121) diff --git a/matplotcheck/tests/conftest.py b/matplotcheck/tests/conftest.py index 17217454..0bad338a 100644 --- a/matplotcheck/tests/conftest.py +++ b/matplotcheck/tests/conftest.py @@ -1,10 +1,10 @@ """Pytest fixtures for matplotcheck tests""" import pytest -import pandas as pd -import geopandas as gpd -from shapely.geometry import Polygon import numpy as np import matplotlib.pyplot as plt +import pandas as pd +import geopandas as gpd +from shapely.geometry import Polygon, LineString from matplotcheck.base import PlotTester @@ -32,14 +32,14 @@ def pd_gdf(): """Create a geopandas GeoDataFrame for testing""" df = pd.DataFrame( { - "lat": np.random.randint(-85, 85, size=100), - "lon": np.random.randint(-180, 180, size=100), + "lat": np.random.randint(-85, 85, size=5), + "lon": np.random.randint(-180, 180, size=5), } ) gdf = gpd.GeoDataFrame( - {"A": np.arange(100)}, geometry=gpd.points_from_xy(df.lon, df.lat) + {"A": np.arange(5)}, geometry=gpd.points_from_xy(df.lon, df.lat) ) - + gdf["attr"] = ["Tree", "Tree", "Bush", "Bush", "Bush"] return gdf @@ -68,6 +68,15 @@ def basic_polygon_gdf(basic_polygon): return gdf +@pytest.fixture +def two_line_gdf(): + """ Create Line Objects For Testing """ + linea = LineString([(1, 1), (2, 2), (3, 2), (5, 3)]) + lineb = LineString([(3, 4), (5, 7), (12, 2), (10, 5), (9, 7.5)]) + gdf = gpd.GeoDataFrame([1, 2], geometry=[linea, lineb], crs="epsg:4326") + return gdf + + @pytest.fixture def pd_xlabels(): """Create a DataFrame which uses the column labels as x-data.""" @@ -85,9 +94,7 @@ def pt_scatter_plt(pd_df): ax.set_xlabel("x label") ax.set_ylabel("y label") - axis = plt.gca() - - return PlotTester(axis) + return PlotTester(ax) @pytest.fixture @@ -110,9 +117,7 @@ def pt_line_plt(pd_df): ax_position.ymax - 0.25, ax_position.ymin - 0.075, "Figure Caption" ) - axis = plt.gca() - - return PlotTester(axis) + return PlotTester(ax) @pytest.fixture @@ -123,9 +128,7 @@ def pt_multi_line_plt(pd_df): ax.set_ylim((0, 140)) ax.legend(loc="center left", title="Legend", bbox_to_anchor=(1, 0.5)) - axis = plt.gca() - - return PlotTester(axis) + return PlotTester(ax) @pytest.fixture @@ -138,9 +141,7 @@ def pt_bar_plt(pd_df): ax.set_xlabel("x label") ax.set_ylabel("y label") - axis = plt.gca() - - return PlotTester(axis) + return PlotTester(ax) @pytest.fixture @@ -150,21 +151,22 @@ def pt_time_line_plt(pd_df_timeseries): pd_df_timeseries.plot("time", "A", kind="line", ax=ax) - axis = plt.gca() - - return PlotTester(axis) + return PlotTester(ax) @pytest.fixture def pt_geo_plot(pd_gdf): """Create a geo plot for testing""" fig, ax = plt.subplots() + size = 0 + point_symb = {"Tree": "green", "Bush": "brown"} - pd_gdf.plot(ax=ax) - ax.set_title("My Plot Title", fontsize=30) - ax.set_xlabel("x label") - ax.set_ylabel("y label") + for ctype, points in pd_gdf.groupby("attr"): + color = point_symb[ctype] + label = ctype + size += 100 + points.plot(color=color, ax=ax, label=label, markersize=size) - axis = plt.gca() + ax.legend(title="Legend", loc=(1.1, 0.1)) - return PlotTester(axis) + return PlotTester(ax) diff --git a/matplotcheck/tests/test_lines.py b/matplotcheck/tests/test_lines.py new file mode 100644 index 00000000..14d68c12 --- /dev/null +++ b/matplotcheck/tests/test_lines.py @@ -0,0 +1,171 @@ +"""Tests for the vector module""" +import matplotlib +import matplotlib.pyplot as plt +import pytest +import geopandas as gpd +from shapely.geometry import LineString + +from matplotcheck.vector import VectorTester + +matplotlib.use("Agg") + + +@pytest.fixture +def multi_line_gdf(two_line_gdf): + """ Create a multi-line GeoDataFrame. + This has one multi line and another regular line. + """ + # Create a single and multi line object + multiline_feat = two_line_gdf.unary_union + linec = LineString([(2, 1), (3, 1), (4, 1), (5, 2)]) + out_df = gpd.GeoDataFrame( + geometry=gpd.GeoSeries([multiline_feat, linec]), crs="epsg:4326", + ) + out_df["attr"] = ["road", "stream"] + return out_df + + +@pytest.fixture +def mixed_type_geo_plot(pd_gdf, multi_line_gdf): + """Create a point plot for testing""" + _, ax = plt.subplots() + + pd_gdf.plot(ax=ax) + multi_line_gdf.plot(ax=ax) + + return VectorTester(ax) + + +@pytest.fixture +def line_geo_plot(two_line_gdf): + """Create a line vector tester object.""" + _, ax = plt.subplots() + + two_line_gdf.plot(ax=ax) + + return VectorTester(ax) + + +@pytest.fixture +def multiline_geo_plot(multi_line_gdf): + """Create a multiline vector tester object.""" + _, ax = plt.subplots() + + multi_line_gdf.plot(ax=ax, column="attr") + + return VectorTester(ax) + + +@pytest.fixture +def multiline_geo_plot_bad(multi_line_gdf): + """Create a multiline vector tester object.""" + _, ax = plt.subplots() + + multi_line_gdf.plot(ax=ax) + + return VectorTester(ax) + + +def test_assert_line_geo(line_geo_plot, two_line_gdf): + """Test that lines are asserted correctly""" + line_geo_plot.assert_lines(two_line_gdf) + plt.close("all") + + +def test_assert_multiline_geo(multiline_geo_plot, multi_line_gdf): + """Test that multi lines are asserted correctly""" + multiline_geo_plot.assert_lines(multi_line_gdf) + plt.close("all") + + +def test_assert_line_geo_fail(line_geo_plot, multi_line_gdf): + """Test that lines fail correctly""" + with pytest.raises(AssertionError, match="Incorrect Line Data"): + line_geo_plot.assert_lines(multi_line_gdf) + plt.close("all") + + +def test_assert_multiline_geo_fail(multiline_geo_plot, two_line_gdf): + """Test that multi lines fail correctly""" + with pytest.raises(AssertionError, match="Incorrect Line Data"): + multiline_geo_plot.assert_lines(two_line_gdf) + plt.close("all") + + +def test_assert_line_fails_list(line_geo_plot): + """Test that assert_lines fails when passed a list""" + linelist = [ + [(1, 1), (2, 2), (3, 2), (5, 3)], + [(3, 4), (5, 7), (12, 2), (10, 5), (9, 7.5)], + ] + with pytest.raises(ValueError, match="lines_expected is not expected ty"): + line_geo_plot.assert_lines(linelist) + plt.close("all") + + +def test_assert_line_geo_passed_nothing(line_geo_plot): + """Test that assertion passes when passed None""" + line_geo_plot.assert_lines(None) + plt.close("all") + + +def test_get_lines_geometry(line_geo_plot): + """Test that get_lines returns the proper values""" + lines = [(LineString(i[0])) for i in line_geo_plot.get_lines().values] + geometries = gpd.GeoDataFrame(geometry=lines) + line_geo_plot.assert_lines(geometries) + plt.close("all") + + +def test_assert_lines_grouped_by_type(multiline_geo_plot, multi_line_gdf): + """Test that assert works for grouped line plots""" + multiline_geo_plot.assert_lines_grouped_by_type(multi_line_gdf, "attr") + plt.close("all") + + +def test_assert_lines_grouped_by_type_fail( + multiline_geo_plot_bad, multi_line_gdf +): + """Test that assert fails for incorrectly grouped line plots""" + with pytest.raises(AssertionError, match="Line attributes not accurate "): + multiline_geo_plot_bad.assert_lines_grouped_by_type( + multi_line_gdf, "attr" + ) + plt.close("all") + + +def test_assert_lines_grouped_by_type_passes_with_none(multiline_geo_plot): + """Test that assert passes if nothing is passed into it""" + multiline_geo_plot.assert_lines_grouped_by_type(None, None) + plt.close("all") + + +def test_assert_lines_grouped_by_type_fails_non_gdf( + multiline_geo_plot, multi_line_gdf +): + """Test that assert fails if a list is passed into it""" + with pytest.raises(ValueError, match="lines_expected is not of expected "): + multiline_geo_plot.assert_lines_grouped_by_type( + multi_line_gdf.to_numpy(), "attr" + ) + plt.close("all") + + +def test_mixed_type_passes(mixed_type_geo_plot, pd_gdf): + """Tests that points passes with a mixed type plot""" + mixed_type_geo_plot.assert_points(pd_gdf) + plt.close("all") + + +def test_get_lines_by_collection(multiline_geo_plot): + """Test that get_lines_by_collection returns the correct values""" + lines_list = [ + [ + [(1, 1), (2, 2), (3, 2), (5, 3)], + [(3, 4), (5, 7), (12, 2), (10, 5), (9, 7.5)], + [(2, 1), (3, 1), (4, 1), (5, 2)], + ] + ] + sorted_lines_list = sorted([sorted(l) for l in lines_list]) + assert sorted_lines_list == multiline_geo_plot.get_lines_by_collection() + plt.close("all") diff --git a/matplotcheck/tests/test_points.py b/matplotcheck/tests/test_points.py new file mode 100644 index 00000000..f2ead5f7 --- /dev/null +++ b/matplotcheck/tests/test_points.py @@ -0,0 +1,161 @@ +"""Tests for the vector module""" +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +import pytest +import geopandas as gpd + +from matplotcheck.vector import VectorTester + +matplotlib.use("Agg") + + +@pytest.fixture +def bad_pd_gdf(pd_gdf): + """Create a point geodataframe with slightly wrong values for testing""" + return gpd.GeoDataFrame( + geometry=gpd.points_from_xy( + pd_gdf.geometry.x + 1, pd_gdf.geometry.y + 1 + ) + ) + + +@pytest.fixture +def origin_pt_gdf(pd_gdf): + """Create a point geodataframe to test assert_points when a point at the + origin of the plot (0, 0) is present in the dataframe. This checks + for a specific bug fix that was added to the assert_points function.""" + origin_pt_gdf = pd_gdf.append( + gpd.GeoDataFrame(geometry=gpd.points_from_xy([0], [0])) + ) + origin_pt_gdf.reset_index(inplace=True, drop=True) + return origin_pt_gdf + + +@pytest.fixture +def pt_geo_plot(pd_gdf): + """Create a geo plot for testing""" + _, ax = plt.subplots() + size = 0 + point_symb = {"Tree": "green", "Bush": "brown"} + + for ctype, points in pd_gdf.groupby("attr"): + color = point_symb[ctype] + label = ctype + size += 100 + points.plot(color=color, ax=ax, label=label, markersize=size) + + ax.legend(title="Legend", loc=(1.1, 0.1)) + + return VectorTester(ax) + + +@pytest.fixture +def pt_geo_plot_bad(pd_gdf): + """Create a geo plot for testing""" + _, ax = plt.subplots() + + pd_gdf.plot(ax=ax) + + return VectorTester(ax) + + +@pytest.fixture +def pt_geo_plot_origin(origin_pt_gdf, two_line_gdf): + """Create a point plot for testing assert_points with a point at the + origin""" + _, ax = plt.subplots() + + origin_pt_gdf.plot(ax=ax) + + two_line_gdf.plot(ax=ax) + + return VectorTester(ax) + + +def test_points_sorted_by_markersize_pass(pt_geo_plot, pd_gdf): + """Tests that points are plotted as different sizes based on an attribute + value passes""" + pt_geo_plot.assert_collection_sorted_by_markersize(pd_gdf, "attr") + plt.close("all") + + +def test_points_sorted_by_markersize_fail(pt_geo_plot_bad, pd_gdf): + """Tests that points are plotted as different sizes based on an attribute + value fails""" + with pytest.raises(AssertionError, match="Markersize not based on"): + pt_geo_plot_bad.assert_collection_sorted_by_markersize(pd_gdf, "attr") + plt.close("all") + + +def test_points_grouped_by_type(pt_geo_plot, pd_gdf): + """Tests that points grouped by type passes""" + pt_geo_plot.assert_points_grouped_by_type(pd_gdf, "attr") + plt.close("all") + + +def test_points_grouped_by_type_fail(pt_geo_plot_bad, pd_gdf): + """Tests that points grouped by type passes""" + with pytest.raises(AssertionError, match="Point attributes not accurate"): + pt_geo_plot_bad.assert_points_grouped_by_type(pd_gdf, "attr") + plt.close("all") + + +def test_point_geometry_pass(pt_geo_plot, pd_gdf): + """Check that the point geometry test recognizes correct points.""" + pt_geo_plot.assert_points(points_expected=pd_gdf) + plt.close("all") + + +def test_point_geometry_fail(pt_geo_plot, bad_pd_gdf): + """Check that the point geometry test recognizes incorrect points.""" + with pytest.raises(AssertionError, match="Incorrect Point Data"): + pt_geo_plot.assert_points(points_expected=bad_pd_gdf) + plt.close("all") + + +def test_assert_point_fails_list(pt_geo_plot, pd_gdf): + """ + Check that the point geometry test fails anything that's not a + GeoDataFrame + """ + list_geo = [list(pd_gdf.geometry.x), list(pd_gdf.geometry.y)] + with pytest.raises(ValueError, match="points_expected is not expected"): + pt_geo_plot.assert_points(points_expected=list_geo) + plt.close("all") + + +def test_get_points(pt_geo_plot, pd_gdf): + """Tests that get_points returns correct values""" + xy_values = pt_geo_plot.get_points() + assert list(sorted(xy_values.x)) == sorted(list(pd_gdf.geometry.x)) + assert list(sorted(xy_values.y)) == sorted(list(pd_gdf.geometry.y)) + plt.close("all") + + +def test_assert_points_custom_message(pt_geo_plot, bad_pd_gdf): + """Tests that a custom error message is passed.""" + message = "Test message" + with pytest.raises(AssertionError, match="Test message"): + pt_geo_plot.assert_points(points_expected=bad_pd_gdf, m=message) + plt.close("all") + + +def test_wrong_length_points_expected(pt_geo_plot, pd_gdf, bad_pd_gdf): + """Tests that error is thrown for incorrect length of a gdf""" + with pytest.raises(AssertionError, match="points_expected's length does "): + pt_geo_plot.assert_points(bad_pd_gdf.append(pd_gdf), "attr") + plt.close("all") + + +def test_convert_length_function_error(pt_geo_plot): + """Test that the convert length function throws an error when given + incorrect inputs""" + with pytest.raises(ValueError, match="Input array length is not: 1 or 9"): + pt_geo_plot._convert_length(np.array([1, 2, 3, 4]), 9) + + +def test_point_gdf_with_point_at_origin(pt_geo_plot_origin, origin_pt_gdf): + """Test that assert_points works when there's a point at the origin in the + gdf""" + pt_geo_plot_origin.assert_points(origin_pt_gdf) diff --git a/matplotcheck/tests/test_vector.py b/matplotcheck/tests/test_polygons.py similarity index 56% rename from matplotcheck/tests/test_vector.py rename to matplotcheck/tests/test_polygons.py index 952df150..d626de3a 100644 --- a/matplotcheck/tests/test_vector.py +++ b/matplotcheck/tests/test_polygons.py @@ -1,9 +1,32 @@ """Tests for the vector module""" -import pytest +import matplotlib import matplotlib.pyplot as plt +import pytest import geopandas as gpd +from shapely.geometry import Polygon + from matplotcheck.vector import VectorTester +matplotlib.use("Agg") + + +@pytest.fixture +def multi_polygon_gdf(basic_polygon): + """ + A GeoDataFrame containing the basic polygon geometry. + Returns + ------- + GeoDataFrame containing the basic_polygon polygon. + """ + poly_a = Polygon([(3, 5), (2, 3.25), (5.25, 6), (2.25, 2), (2, 2)]) + gdf = gpd.GeoDataFrame( + [1, 2], geometry=[poly_a, basic_polygon], crs="epsg:4326", + ) + multi_gdf = gpd.GeoDataFrame( + geometry=gpd.GeoSeries(gdf.unary_union), crs="epsg:4326" + ) + return multi_gdf + @pytest.fixture def poly_geo_plot(basic_polygon_gdf): @@ -11,38 +34,18 @@ def poly_geo_plot(basic_polygon_gdf): _, ax = plt.subplots() basic_polygon_gdf.plot(ax=ax) - ax.set_title("My Plot Title", fontsize=30) - ax.set_xlabel("x label") - ax.set_ylabel("y label") - axis = plt.gca() - - return VectorTester(axis) + return VectorTester(ax) @pytest.fixture -def point_geo_plot(pd_gdf): - """Create a point plot for testing""" +def multi_poly_geo_plot(multi_polygon_gdf): + """Create a mutlipolygon vector tester object.""" _, ax = plt.subplots() - pd_gdf.plot(ax=ax) - ax.set_title("My Plot Title", fontsize=30) - ax.set_xlabel("x label") - ax.set_ylabel("y label") - - axis = plt.gca() + multi_polygon_gdf.plot(ax=ax) - return VectorTester(axis) - - -@pytest.fixture -def bad_pd_gdf(pd_gdf): - """Create a point geodataframe with slightly wrong values for testing""" - return gpd.GeoDataFrame( - geometry=gpd.points_from_xy( - pd_gdf.geometry.x + 1, pd_gdf.geometry.y + 1 - ) - ) + return VectorTester(ax) def test_list_of_polygons_check(poly_geo_plot, basic_polygon): @@ -50,34 +53,34 @@ def test_list_of_polygons_check(poly_geo_plot, basic_polygon): x, y = basic_polygon.exterior.coords.xy poly_list = [list(zip(x, y))] poly_geo_plot.assert_polygons(poly_list) - plt.close() + plt.close("all") def test_polygon_geodataframe_check(poly_geo_plot, basic_polygon_gdf): """Check that the polygon assert works with a polygon geodataframe""" poly_geo_plot.assert_polygons(basic_polygon_gdf) - plt.close() + plt.close("all") def test_empty_list_polygon_check(poly_geo_plot): """Check that the polygon assert fails an empty list.""" with pytest.raises(ValueError, match="Empty list or GeoDataFrame "): poly_geo_plot.assert_polygons([]) - plt.close() + plt.close("all") def test_empty_list_entry_polygon_check(poly_geo_plot): """Check that the polygon assert fails a list with an empty entry.""" with pytest.raises(ValueError, match="Empty list or GeoDataFrame "): poly_geo_plot.assert_polygons([[]]) - plt.close() + plt.close("all") def test_empty_gdf_polygon_check(poly_geo_plot): """Check that the polygon assert fails an empty GeoDataFrame.""" with pytest.raises(ValueError, match="Empty list or GeoDataFrame "): poly_geo_plot.assert_polygons(gpd.GeoDataFrame([])) - plt.close() + plt.close("all") def test_polygon_dec_check(poly_geo_plot, basic_polygon): @@ -88,7 +91,7 @@ def test_polygon_dec_check(poly_geo_plot, basic_polygon): x, y = basic_polygon.exterior.coords.xy poly_list = [[(x[0] + 0.1, x[1]) for x in list(zip(x, y))]] poly_geo_plot.assert_polygons(poly_list, dec=1) - plt.close() + plt.close("all") def test_polygon_dec_check_fail(poly_geo_plot, basic_polygon): @@ -100,7 +103,7 @@ def test_polygon_dec_check_fail(poly_geo_plot, basic_polygon): x, y = basic_polygon.exterior.coords.xy poly_list = [(x[0] + 0.5, x[1]) for x in list(zip(x, y))] poly_geo_plot.assert_polygons(poly_list, dec=1) - plt.close() + plt.close("all") def test_polygon_custom_fail_message(poly_geo_plot, basic_polygon): @@ -109,39 +112,10 @@ def test_polygon_custom_fail_message(poly_geo_plot, basic_polygon): x, y = basic_polygon.exterior.coords.xy poly_list = [(x[0] + 0.5, x[1]) for x in list(zip(x, y))] poly_geo_plot.assert_polygons(poly_list, m="Test Message") - plt.close() - - -def test_point_geometry_pass(point_geo_plot, pd_gdf): - """Check that the point geometry test recognizes correct points.""" - point_geo_plot.assert_points(points_expected=pd_gdf) - - -def test_point_geometry_fail(point_geo_plot, bad_pd_gdf): - """Check that the point geometry test recognizes incorrect points.""" - with pytest.raises(AssertionError, match="Incorrect Point Data"): - point_geo_plot.assert_points(points_expected=bad_pd_gdf) - - -def test_assert_point_fails_list(point_geo_plot, pd_gdf): - """ - Check that the point geometry test fails anything that's not a - GeoDataFrame - """ - list_geo = [list(pd_gdf.geometry.x), list(pd_gdf.geometry.y)] - with pytest.raises(ValueError, match="points_expected is not expected"): - point_geo_plot.assert_points(points_expected=list_geo) - - -def test_get_points(point_geo_plot, pd_gdf): - """Tests that get_points returns correct values""" - xy_values = point_geo_plot.get_points() - assert list(sorted(xy_values.x)) == sorted(list(pd_gdf.geometry.x)) - assert list(sorted(xy_values.y)) == sorted(list(pd_gdf.geometry.y)) + plt.close("all") -def test_assert_points_custom_message(point_geo_plot, bad_pd_gdf): - """Tests that a custom error message is passed.""" - message = "Test message" - with pytest.raises(AssertionError, match="Test message"): - point_geo_plot.assert_points(points_expected=bad_pd_gdf, m=message) +def test_multi_polygon_pass(multi_poly_geo_plot, multi_polygon_gdf): + """Check a multipolygon passes""" + multi_poly_geo_plot.assert_polygons(multi_polygon_gdf) + plt.close("all")