diff --git a/CHANGELOG.md b/CHANGELOG.md index 046dccfa..415c27fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] * Created a vignette covering the testing of histograms (@ryla5068, #149) +* Allowed `assert_polygons()` to accept GeoDataFrames and added tests (@nkorinek, #175) ## [0.1.1] * Added test for bin heights of histograms (@ryla5068, #124) diff --git a/dev-requirements.txt b/dev-requirements.txt index 43d610df..635fa63a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -10,3 +10,4 @@ codecov==2.0.15 setuptools==45.2.0 pre-commit==1.20.0 pip==19.0.3 +descartes==1.1.0 diff --git a/matplotcheck/tests/conftest.py b/matplotcheck/tests/conftest.py index 4931fd81..278b3380 100644 --- a/matplotcheck/tests/conftest.py +++ b/matplotcheck/tests/conftest.py @@ -2,6 +2,7 @@ import pytest import pandas as pd import geopandas as gpd +from shapely.geometry import Polygon import numpy as np import matplotlib.pyplot as plt from matplotcheck.base import PlotTester @@ -42,6 +43,31 @@ def pd_gdf(): return gdf +@pytest.fixture +def basic_polygon(): + """ + A square polygon spanning (2, 2) to (4.25, 4.25) in x and y directions. + Borrowed from rasterio/tests/conftest.py. + Returns + ------- + dict: GeoJSON-style geometry object. + Coordinates are in grid coordinates (Affine.identity()). + """ + return Polygon([(2, 2), (2, 4.25), (4.25, 4.25), (4.25, 2), (2, 2)]) + + +@pytest.fixture +def basic_polygon_gdf(basic_polygon): + """ + A GeoDataFrame containing the basic polygon geometry. + Returns + ------- + GeoDataFrame containing the basic_polygon polygon. + """ + gdf = gpd.GeoDataFrame(geometry=[basic_polygon], crs={"init": "epsg:4326"}) + return gdf + + @pytest.fixture def pd_xlabels(): """Create a DataFrame which uses the column labels as x-data.""" diff --git a/matplotcheck/tests/test_vector.py b/matplotcheck/tests/test_vector.py new file mode 100644 index 00000000..e65a0611 --- /dev/null +++ b/matplotcheck/tests/test_vector.py @@ -0,0 +1,87 @@ +"""Tests for the vector module""" +import pytest +import matplotlib.pyplot as plt +import geopandas as gpd +from matplotcheck.vector import VectorTester + + +@pytest.fixture +def poly_geo_plot(basic_polygon_gdf): + """Create a polygon vector tester object.""" + _, ax = plt.subplots() + + basic_polygon_gdf.plot(ax=ax) + ax.set_title("My Plot Title", fontsize=30) + ax.set_xlabel("x label") + ax.set_ylabel("y label") + + axis = plt.gca() + + return VectorTester(axis) + + +def test_list_of_polygons_check(poly_geo_plot, basic_polygon): + """Check that the polygon assert works with a list of polygons.""" + x, y = basic_polygon.exterior.coords.xy + poly_list = [list(zip(x, y))] + poly_geo_plot.assert_polygons(poly_list) + plt.close() + + +def test_polygon_geodataframe_check(poly_geo_plot, basic_polygon_gdf): + """Check that the polygon assert works with a polygon geodataframe""" + poly_geo_plot.assert_polygons(basic_polygon_gdf) + plt.close() + + +def test_empty_list_polygon_check(poly_geo_plot): + """Check that the polygon assert fails an empty list.""" + with pytest.raises(ValueError, match="Empty list or GeoDataFrame "): + poly_geo_plot.assert_polygons([]) + plt.close() + + +def test_empty_list_entry_polygon_check(poly_geo_plot): + """Check that the polygon assert fails a list with an empty entry.""" + with pytest.raises(ValueError, match="Empty list or GeoDataFrame "): + poly_geo_plot.assert_polygons([[]]) + plt.close() + + +def test_empty_gdf_polygon_check(poly_geo_plot): + """Check that the polygon assert fails an empty GeoDataFrame.""" + with pytest.raises(ValueError, match="Empty list or GeoDataFrame "): + poly_geo_plot.assert_polygons(gpd.GeoDataFrame([])) + plt.close() + + +def test_polygon_dec_check(poly_geo_plot, basic_polygon): + """ + Check that the polygon assert passes when the polygon is off by less than + the maximum decimal precision. + """ + x, y = basic_polygon.exterior.coords.xy + poly_list = [[(x[0] + 0.1, x[1]) for x in list(zip(x, y))]] + poly_geo_plot.assert_polygons(poly_list, dec=1) + plt.close() + + +def test_polygon_dec_check_fail(poly_geo_plot, basic_polygon): + """ + Check that the polygon assert fails when the polygon is off by more than + the maximum decimal precision. + """ + with pytest.raises(AssertionError, match="Incorrect Polygon"): + x, y = basic_polygon.exterior.coords.xy + poly_list = [(x[0] + 0.5, x[1]) for x in list(zip(x, y))] + poly_geo_plot.assert_polygons(poly_list, dec=1) + plt.close() + + +def test_polygon_custom_fail_message(poly_geo_plot, basic_polygon): + """Check that the corrct error message is raised when polygons fail""" + with pytest.raises(AssertionError, match="Test Message"): + x, y = basic_polygon.exterior.coords.xy + poly_list = [(x[0] + 0.5, x[1]) for x in list(zip(x, y))] + poly_geo_plot.assert_polygons(poly_list, m="Test Message") + plt.close() diff --git a/matplotcheck/vector.py b/matplotcheck/vector.py index acc44459..1996b4bc 100644 --- a/matplotcheck/vector.py +++ b/matplotcheck/vector.py @@ -80,7 +80,7 @@ def assert_no_legend_overlap(self, m="Legends overlap eachother"): def _convert_length(self, arr, n): """ helper function for 'get_points_by_attributes' and 'get_lines_by_attributes' - takes an array of either legnth 1 or n. + takes an array of either legnth 1 or n. If array is length 1: array of array's only element repeating n times is returned If array is length n: original array is returned Else: function raises value error @@ -106,7 +106,7 @@ def _convert_length(self, arr, n): ) def get_points_by_attributes(self): - """Returns a sorted list of lists where each list contains tuples of xycoords for points of + """Returns a sorted list of lists where each list contains tuples of xycoords for points of the same attributes: color, marker, and markersize Returns @@ -287,11 +287,11 @@ def _convert_linestyle(self, ls): return (ls[0], onoffseq) def get_lines(self): - """Returns a dataframe with all lines on ax + """Returns a dataframe with all lines on ax Returns ------- - output: DataFrame with column 'lines'. Each row represents one line segment. + output: DataFrame with column 'lines'. Each row represents one line segment. Its value in 'lines' is a list of tuples representing the line segement. """ lines = [ @@ -472,15 +472,29 @@ def assert_polygons( self, polygons_expected, dec=None, m="Incorrect Polygon Data" ): """Asserts the polygon data in Axes ax is equal to polygons_expected to decimal place dec with error message m - If polygons_expected is am empty list or None, assertion is passed - - Parameters - ---------- - polygons_expected: list of polygons expected to be founds on Axes ax - dec: int stating the desired decimal precision. If None, polygons must be exact - m: string error message if assertion is not met + If polygons_expected is am empty list or None, assertion is passed. + + Parameters + ---------- + polygons_expected : List or GeoDataFrame + List of polygons expected to be founds on Axes ax or a GeoDataFrame + containing the expected polygons. + dec : int (Optional) + Int stating the desired decimal precision. If None, polygons must + be exact. + m : string (default = "Incorrect Polygon Data") + String error message if assertion is not met. """ - if polygons_expected: + if len(polygons_expected) != 0: + if isinstance(polygons_expected, list): + if len(polygons_expected[0]) == 0: + raise ValueError( + "Empty list or GeoDataFrame passed into assert_polygons." + ) + if isinstance(polygons_expected, gpd.geodataframe.GeoDataFrame): + polygons_expected = self._convert_multipolygons( + polygons_expected["geometry"] + ) polygons = self.get_polygons() if dec: assert len(polygons_expected) == len(polygons), m @@ -494,3 +508,7 @@ def assert_polygons( ) else: np.testing.assert_equal(polygons, sorted(polygons_expected), m) + else: + raise ValueError( + "Empty list or GeoDataFrame passed into assert_polygons." + )