diff --git a/CHANGELOG.md b/CHANGELOG.md index b061a7af..86f7f067 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +* Created functions to test point geometries in VectorTester (@nkorinek, #176) * made `assert_string_contains()` accept correct strings with spaces in them (@nkorinek, #182) * added contributors file and updated README to remove that information (@nkorinek, #121) diff --git a/matplotcheck/tests/test_vector.py b/matplotcheck/tests/test_vector.py index e65a0611..952df150 100644 --- a/matplotcheck/tests/test_vector.py +++ b/matplotcheck/tests/test_vector.py @@ -20,6 +20,31 @@ def poly_geo_plot(basic_polygon_gdf): return VectorTester(axis) +@pytest.fixture +def point_geo_plot(pd_gdf): + """Create a point plot for testing""" + _, ax = plt.subplots() + + pd_gdf.plot(ax=ax) + ax.set_title("My Plot Title", fontsize=30) + ax.set_xlabel("x label") + ax.set_ylabel("y label") + + axis = plt.gca() + + return VectorTester(axis) + + +@pytest.fixture +def bad_pd_gdf(pd_gdf): + """Create a point geodataframe with slightly wrong values for testing""" + return gpd.GeoDataFrame( + geometry=gpd.points_from_xy( + pd_gdf.geometry.x + 1, pd_gdf.geometry.y + 1 + ) + ) + + def test_list_of_polygons_check(poly_geo_plot, basic_polygon): """Check that the polygon assert works with a list of polygons.""" x, y = basic_polygon.exterior.coords.xy @@ -85,3 +110,38 @@ def test_polygon_custom_fail_message(poly_geo_plot, basic_polygon): poly_list = [(x[0] + 0.5, x[1]) for x in list(zip(x, y))] poly_geo_plot.assert_polygons(poly_list, m="Test Message") plt.close() + + +def test_point_geometry_pass(point_geo_plot, pd_gdf): + """Check that the point geometry test recognizes correct points.""" + point_geo_plot.assert_points(points_expected=pd_gdf) + + +def test_point_geometry_fail(point_geo_plot, bad_pd_gdf): + """Check that the point geometry test recognizes incorrect points.""" + with pytest.raises(AssertionError, match="Incorrect Point Data"): + point_geo_plot.assert_points(points_expected=bad_pd_gdf) + + +def test_assert_point_fails_list(point_geo_plot, pd_gdf): + """ + Check that the point geometry test fails anything that's not a + GeoDataFrame + """ + list_geo = [list(pd_gdf.geometry.x), list(pd_gdf.geometry.y)] + with pytest.raises(ValueError, match="points_expected is not expected"): + point_geo_plot.assert_points(points_expected=list_geo) + + +def test_get_points(point_geo_plot, pd_gdf): + """Tests that get_points returns correct values""" + xy_values = point_geo_plot.get_points() + assert list(sorted(xy_values.x)) == sorted(list(pd_gdf.geometry.x)) + assert list(sorted(xy_values.y)) == sorted(list(pd_gdf.geometry.y)) + + +def test_assert_points_custom_message(point_geo_plot, bad_pd_gdf): + """Tests that a custom error message is passed.""" + message = "Test message" + with pytest.raises(AssertionError, match="Test message"): + point_geo_plot.assert_points(points_expected=bad_pd_gdf, m=message) diff --git a/matplotcheck/vector.py b/matplotcheck/vector.py index f91f8681..ab77c8da 100644 --- a/matplotcheck/vector.py +++ b/matplotcheck/vector.py @@ -21,63 +21,6 @@ def __init__(self, ax): """Initialize the vector tester""" super(VectorTester, self).__init__(ax) - def assert_legend_no_overlay_content( - self, m="Legend overlays plot contents" - ): - """Asserts that each legend does not overlay plot contents with error - message m - - Parameters - --------- - m: string - error message if assertion is not met - """ - plot_extent = self.ax.get_window_extent().get_points() - legends = self.get_legends() - for leg in legends: - leg_extent = leg.get_window_extent().get_points() - legend_left = leg_extent[1][0] < plot_extent[0][0] - legend_right = leg_extent[0][0] > plot_extent[1][0] - legend_below = leg_extent[1][1] < plot_extent[0][1] - assert legend_left or legend_right or legend_below, m - - def _legends_overlap(self, b1, b2): - """Helper function for assert_no_legend_overlap - Boolean value if points of window extents for b1 and b2 overlap - - Parameters - ---------- - b1: bounding box of window extents - b2: bounding box of window extents - - Returns - ------ - boolean value that says if bounding boxes b1 and b2 overlap - """ - x_overlap = (b1[0][0] <= b2[1][0] and b1[0][0] >= b2[0][0]) or ( - b1[1][0] <= b2[1][0] and b1[1][0] >= b2[0][0] - ) - y_overlap = (b1[0][1] <= b2[1][1] and b1[0][1] >= b2[0][1]) or ( - b1[1][1] <= b2[1][1] and b1[1][1] >= b2[0][1] - ) - return x_overlap and y_overlap - - def assert_no_legend_overlap(self, m="Legends overlap eachother"): - """Asserts there are no two legends in Axes ax that overlap each other - with error message m - - Parameters - ---------- - m: string error message if assertion is not met - """ - legends = self.get_legends() - n = len(legends) - for i in range(n - 1): - leg_extent1 = legends[i].get_window_extent().get_points() - for j in range(i + 1, n): - leg_extent2 = legends[j].get_window_extent().get_points() - assert not self._legends_overlap(leg_extent1, leg_extent2), m - """ Check Data """ def _convert_length(self, arr, n): @@ -160,7 +103,7 @@ def get_points_by_attributes(self): return sorted([sorted(p) for p in points_grouped]) def assert_points_grouped_by_type( - self, data_exp, sort_column, m="Point attribtues not accurate by type" + self, data_exp, sort_column, m="Point attributes not accurate by type" ): """Asserts that the points on Axes ax display attributes based on their type with error message m @@ -196,27 +139,28 @@ def sort_collection_by_markersize(self): """ df = pd.DataFrame(columns=("x", "y", "markersize")) for c in self.ax.collections: - offsets, markersizes = c.get_offsets(), c.get_sizes() - x_data, y_data = ( - [offset[0] for offset in offsets], - [offset[1] for offset in offsets], - ) - if len(markersizes) == 1: - markersize = [markersizes[0]] * len(offsets) - df2 = pd.DataFrame( - {"x": x_data, "y": y_data, "markersize": markersize} - ) - df = df.append(df2) - elif len(markersizes) == len(offsets): - df2 = pd.DataFrame( - {"x": x_data, "y": y_data, "markersize": markersizes} + if isinstance(c, matplotlib.collections.PathCollection): + offsets, markersizes = c.get_offsets(), c.get_sizes() + x_data, y_data = ( + [offset[0] for offset in offsets], + [offset[1] for offset in offsets], ) - df = df.append(df2) + if len(markersizes) == 1: + markersize = [markersizes[0]] * len(offsets) + df2 = pd.DataFrame( + {"x": x_data, "y": y_data, "markersize": markersize} + ) + df = df.append(df2) + elif len(markersizes) == len(offsets): + df2 = pd.DataFrame( + {"x": x_data, "y": y_data, "markersize": markersizes} + ) + df = df.append(df2) df = df.sort_values(by="markersize").reset_index(drop=True) return df def assert_collection_sorted_by_markersize(self, df_expected, sort_column): - """Asserts a collection of points vary in size by column expresse din + """Asserts a collection of points vary in size by column expressed in sort_column Parameters @@ -244,7 +188,69 @@ def assert_collection_sorted_by_markersize(self, df_expected, sort_column): err_msg="Markersize not based on {0} values".format(sort_column), ) - """Check lines""" + def get_points(self): + """Returns a Pandas dataframe with all x, y values for points on axis. + + Returns + ------- + output: DataFrame with columns 'x' and 'y'. Each row represents one + points coordinates. + """ + points = self.get_xy(points_only=True).sort_values(by=["x", "y"]) + points.reset_index(inplace=True, drop=True) + return points + + def assert_points(self, points_expected, m="Incorrect Point Data"): + """ + Asserts the point data in Axes ax is equal to points_expected data + with error message m. + If points_expected not a GeoDataFrame, test fails. + + Parameters + ---------- + points_expected : GeoDataFrame + GeoDataFrame with the expected points for the axis. + m : string (default = "Incorrect Point Data") + String error message if assertion is not met. + """ + if isinstance(points_expected, gpd.geodataframe.GeoDataFrame): + points = self.get_points() + xy_expected = pd.DataFrame(columns=["x", "y"]) + xy_expected["x"] = points_expected.geometry.x + xy_expected["y"] = points_expected.geometry.y + xy_expected = xy_expected.sort_values(by=["x", "y"]) + xy_expected.reset_index(inplace=True, drop=True) + # Fix for failure if more than points were plotted in matplotlib + if len(points) != len(xy_expected): + # Checks if there are extra 0, 0 coords in the DataFrame + # returned from self.get_points and removes them. + points_zeros = (points["x"] == 0) & (points["y"] == 0) + if points_zeros.any(): + expected_zeros = (xy_expected["x"] == 0) & ( + xy_expected["y"] == 0 + ) + keep = expected_zeros.sum() + zeros_index_vals = points_zeros.index[ + points_zeros.tolist() + ] + for i in range(keep): + points_zeros.at[zeros_index_vals[i]] = False + points = points[~points_zeros].reset_index(drop=True) + else: + raise AssertionError( + "points_expected's length does not match the stored" + "data's length." + ) + try: + pd.testing.assert_frame_equal(left=points, right=xy_expected) + except AssertionError: + raise AssertionError(m) + else: + raise ValueError( + "points_expected is not expected type: GeoDataFrame" + ) + + # Lines def _convert_multilines(self, df, column_title): """Helper function for get_lines_by_attribute @@ -449,7 +455,7 @@ def assert_lines_grouped_by_type( grouped_exp = sorted([sorted(l) for l in grouped_exp]) plt.close(fig) np.testing.assert_equal(groups, grouped_exp, m) - elif not lines_expected: + elif lines_expected is None: pass else: raise ValueError(