diff --git a/CHANGELOG.md b/CHANGELOG.md index 415c27fc..eca75ef4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] * Created a vignette covering the testing of histograms (@ryla5068, #149) +* Created `get_plot_image()` function for the RasterTester object (@nkorinek, #192) * Allowed `assert_polygons()` to accept GeoDataFrames and added tests (@nkorinek, #175) ## [0.1.1] diff --git a/matplotcheck/raster.py b/matplotcheck/raster.py index 9e8aa5f0..4f62a8ce 100644 --- a/matplotcheck/raster.py +++ b/matplotcheck/raster.py @@ -164,6 +164,24 @@ def assert_legend_accuracy_classified_image( ### IMAGE TESTS/HELPER FUNCTIONS ### + def get_plot_image(self): + """Returns images stored on the Axes object as a list of numpy arrays. + + Returns + ------- + im_data: List + Numpy array of images stored on Axes object. + """ + im_data = [] + if self.ax.get_images(): + im_data = self.ax.get_images()[0].get_array() + assert list(im_data), "No Image Displayed" + + # If image array has 3 dims (e.g. rgb image), remove alpha channel + if len(im_data.shape) == 3: + im_data = im_data[:, :, :3] + return im_data + def assert_image( self, im_expected, im_classified=False, m="Incorrect Image Displayed" ): @@ -171,12 +189,14 @@ def assert_image( Parameters ---------- - im_expected: array containing the expected image data - im_classified: boolean, set to True image has been classified. - Since classified images values can be reversed or shifted and still - produce the same image, setting this to True will allow those - changes. - m: string error message if assertion is not met + im_expected: Numpy Array + Array containing the expected image data. + im_classified: boolean + Set to True image has been classified. Since classified images + values can be reversed or shifted and still produce the same image, + setting this to True will allow those changes. + m: string + String error message if assertion is not met. Returns ---------- diff --git a/matplotcheck/tests/test_raster.py b/matplotcheck/tests/test_raster.py index ab59e62e..dd623660 100644 --- a/matplotcheck/tests/test_raster.py +++ b/matplotcheck/tests/test_raster.py @@ -317,3 +317,9 @@ def test_raster_assert_image_fullscreen_blank(raster_plt_blank): with pytest.raises(AssertionError, match="No image found on axes"): raster_plt_blank.assert_image_full_screen() plt.close() + + +def test_get_plot_images(raster_plt_rgb): + """get_plot_image should get correct image from ax object""" + ax_im = raster_plt_rgb.get_plot_image() + raster_plt_rgb.assert_image(ax_im) diff --git a/matplotcheck/vector.py b/matplotcheck/vector.py index 1996b4bc..10610602 100644 --- a/matplotcheck/vector.py +++ b/matplotcheck/vector.py @@ -11,11 +11,11 @@ class VectorTester(PlotTester): """A PlotTester for spatial vector plots. - Parameters - ---------- - ax: ```matplotlib.axes.Axes``` object + Parameters + ---------- + ax: ```matplotlib.axes.Axes``` object - """ + """ def __init__(self, ax): """Initialize the vector tester""" @@ -26,10 +26,11 @@ def assert_legend_no_overlay_content( ): """Asserts that each legend does not overlay plot contents with error message m - Parameters - --------- - m: string error message if assertion is not met - """ + Parameters + --------- + m: string + error message if assertion is not met + """ plot_extent = self.ax.get_window_extent().get_points() legends = self.get_legends() for leg in legends: @@ -41,17 +42,17 @@ def assert_legend_no_overlay_content( def _legends_overlap(self, b1, b2): """Helper function for assert_no_legend_overlap - Boolean value if points of window extents for b1 and b2 overlap + Boolean value if points of window extents for b1 and b2 overlap - parmeters - --------- - b1: bounding box of window extents - b2: bounding box of window extents + Parameters + ---------- + b1: bounding box of window extents + b2: bounding box of window extents - Returns - ------ - boolean value that says if bounding boxes b1 and b2 overlap - """ + Returns + ------ + boolean value that says if bounding boxes b1 and b2 overlap + """ x_overlap = (b1[0][0] <= b2[1][0] and b1[0][0] >= b2[0][0]) or ( b1[1][0] <= b2[1][0] and b1[1][0] >= b2[0][0] ) @@ -63,10 +64,10 @@ def _legends_overlap(self, b1, b2): def assert_no_legend_overlap(self, m="Legends overlap eachother"): """Asserts there are no two legends in Axes ax that overlap each other with error message m - Parameters - ---------- - m: string error message if assertion is not met - """ + Parameters + ---------- + m: string error message if assertion is not met + """ legends = self.get_legends() n = len(legends) for i in range(n - 1): @@ -80,20 +81,22 @@ def assert_no_legend_overlap(self, m="Legends overlap eachother"): def _convert_length(self, arr, n): """ helper function for 'get_points_by_attributes' and 'get_lines_by_attributes' - takes an array of either legnth 1 or n. - If array is length 1: array of array's only element repeating n times is returned - If array is length n: original array is returned - Else: function raises value error - - Parameters - ---------- - arr: array of either length 1 or n - n: length of return array - - Returns - ------- - array of length n - """ + takes an array of either length 1 or n. + If array is length 1: array of array's only element repeating n times is returned + If array is length n: original array is returned + Else: function raises value error + + Parameters + ---------- + arr: array + A numpy array of either length 1 or n + n: int + length of return array + + Returns + ------- + array of length n + """ if len(arr) == 1: return list(arr) * n elif len(arr) == n: @@ -107,13 +110,13 @@ def _convert_length(self, arr, n): def get_points_by_attributes(self): """Returns a sorted list of lists where each list contains tuples of xycoords for points of - the same attributes: color, marker, and markersize + the same attributes: color, marker, and markersize - Returns - ------- - sorted list where each list represents all points with the same color. - each point is represented by a tuple with its coordinates. - """ + Returns + ------- + sorted list where each list represents all points with the same color. + each point is represented by a tuple with its coordinates. + """ points_dataframe = pd.DataFrame( columns=["offset", "color", "msize", "mstyle"] ) @@ -160,16 +163,16 @@ def assert_points_grouped_by_type( self, data_exp, sort_column, m="Point attribtues not accurate by type" ): """Asserts that the points on Axes ax display attributes based on their type with error message m - attributes tested are: color, marker, and markersize + attributes tested are: color, marker, and markersize - Parameters - ---------- - data_exp: Geopandas Dataframe with Point objects in column 'geometry' - an additional column with title sort_column, denotes a category for each point - sort_column: string of column label in dataframe data_exp. - this column contains values expressing which points belong to which group - m: string error message if assertion is not met - """ + Parameters + ---------- + data_exp: Geopandas Dataframe with Point objects in column 'geometry' + an additional column with title sort_column, denotes a category for each point + sort_column: string of column label in dataframe data_exp. + this column contains values expressing which points belong to which group + m: string error message if assertion is not met + """ groups = self.get_points_by_attributes() grouped_exp = [ @@ -183,10 +186,10 @@ def assert_points_grouped_by_type( def sort_collection_by_markersize(self): """ Returns a pandas dataframe of points in collections on Axes ax. - Returns - -------- - pandas dataframe with columns x, y, point_size. Each row reprsents a point on Axes ax with location x,y and markersize pointsize - """ + Returns + -------- + pandas dataframe with columns x, y, point_size. Each row reprsents a point on Axes ax with location x,y and markersize pointsize + """ df = pd.DataFrame(columns=("x", "y", "markersize")) for c in self.ax.collections: offsets, markersizes = c.get_offsets(), c.get_sizes() @@ -211,12 +214,12 @@ def sort_collection_by_markersize(self): def assert_collection_sorted_by_markersize(self, df_expected, sort_column): """Asserts a collection of points vary in size by column expresse din sort_column - Parameters - ---------- - df_expected: geopandas dataframe with geometry column of expected point locations - sort_column: column title from df_expected that points are expected to be sorted by - if None, assertion is passed - """ + Parameters + ---------- + df_expected: geopandas dataframe with geometry column of expected point locations + sort_column: column title from df_expected that points are expected to be sorted by + if None, assertion is passed + """ df = self.sort_collection_by_markersize() df_expected = df_expected.sort_values(by=sort_column).reset_index( drop=True @@ -238,20 +241,20 @@ def assert_collection_sorted_by_markersize(self, df_expected, sort_column): def _convert_multilines(self, df, column_title): """Helper function for get_lines_by_attribute - converts a pandas dataframe containing a column of LineString and MultiLinestring objects - to a pandas dataframe where each row represents a single line. Line segment values are converted - to a list of tuples. - - Parameters - --------- - df: pandas Dataframe containing a column of LineString and MultiLinestring objects - column_title: string of column title which holds LineString and MultLinestring objects - - Returns - ------- - Dataframe where each row repsrents a single line. - Line segments values are converted to a list of tuples in column column_title - """ + converts a pandas dataframe containing a column of LineString and MultiLinestring objects + to a pandas dataframe where each row represents a single line. Line segment values are converted + to a list of tuples. + + Parameters + --------- + df: pandas Dataframe containing a column of LineString and MultiLinestring objects + column_title: string of column title which holds LineString and MultLinestring objects + + Returns + ------- + Dataframe where each row repsrents a single line. + Line segments values are converted to a list of tuples in column column_title + """ dfout = df.copy() for i, row in dfout.iterrows(): seg = row[column_title] @@ -271,16 +274,16 @@ def _convert_multilines(self, df, column_title): def _convert_linestyle(self, ls): """helper function for get_lines_by_attributes. - converts linestyle to a tuple of (offset, onoffseq) to get hashable datatypes + converts linestyle to a tuple of (offset, onoffseq) to get hashable datatypes - Parameters - ---------- - ls: linesytle from a LineCollection retreived by get_linestyle() + Parameters + ---------- + ls: linesytle from a LineCollection retreived by get_linestyle() - Returns - ------- - tuple containing (offset, onoffseq) of linestyle - """ + Returns + ------- + tuple containing (offset, onoffseq) of linestyle + """ onoffseq = ls[1] if onoffseq: onoffseq = tuple(ls[1]) @@ -289,11 +292,11 @@ def _convert_linestyle(self, ls): def get_lines(self): """Returns a dataframe with all lines on ax - Returns - ------- - output: DataFrame with column 'lines'. Each row represents one line segment. - Its value in 'lines' is a list of tuples representing the line segement. - """ + Returns + ------- + output: DataFrame with column 'lines'. Each row represents one line segment. + Its value in 'lines' is a list of tuples representing the line segement. + """ lines = [ [tuple(coords) for coords in seg] for c in self.ax.collections @@ -305,10 +308,10 @@ def get_lines(self): def get_lines_by_collection(self): """Returns a sorted list of list where each list contains line segments from the same collections - Returns - ------- - sorted list where each list represents all lines from the same collection - """ + Returns + ------- + sorted list where each list represents all lines from the same collection + """ lines_grouped = [ [[tuple(coords) for coords in seg] for seg in c.get_segments()] for c in self.ax.collections @@ -318,12 +321,12 @@ def get_lines_by_collection(self): def get_lines_by_attributes(self): """Returns a sorted list of lists where each list contains line segments of the same attributes: - color, linewidth, and linestyle + color, linewidth, and linestyle - Returns - ------ - sorted list where each list represents all lines with the same attributes - """ + Returns + ------ + sorted list where each list represents all lines with the same attributes + """ lines_dataframe = pd.DataFrame( columns=["seg", "color", "lwidth", "lstyle"] ) @@ -366,13 +369,13 @@ def get_lines_by_attributes(self): def assert_lines(self, lines_expected, m="Incorrect Line Data"): """Asserts the line data in Axes ax is equal to lines_expected with error message m. - If line_expected is None or an empty list, assertion is passed + If line_expected is None or an empty list, assertion is passed - Parameters - ---------- - lines_expected: Geopandas Dataframe with a geometry column consisting of MultilineString and LineString objects - m: string error message if assertion is not met - """ + Parameters + ---------- + lines_expected: Geopandas Dataframe with a geometry column consisting of MultilineString and LineString objects + m: string error message if assertion is not met + """ if type(lines_expected) == gpd.geodataframe.GeoDataFrame: lines_expected = lines_expected[ lines_expected["geometry"].is_empty == False @@ -398,14 +401,14 @@ def assert_lines_grouped_by_type( m="Line attributes not accurate by type", ): """Asserts that the lines on Axes ax display like attributes based on their type with error message m - attributes tested are: color, linewidth, linestyle - - Parameters - ---------- - lines_expected: Geopandas Dataframe with geometry column consisting of MultiLineString and LineString objects - sort_column: string of column title in lines_expected that contains types lines are expected to be grouped by - m: string error message if assertion is not met - """ + attributes tested are: color, linewidth, linestyle + + Parameters + ---------- + lines_expected: Geopandas Dataframe with geometry column consisting of MultiLineString and LineString objects + sort_column: string of column title in lines_expected that contains types lines are expected to be grouped by + m: string error message if assertion is not met + """ if type(lines_expected) == gpd.geodataframe.GeoDataFrame: groups = self.get_lines_by_attributes() lines_expected = lines_expected[ @@ -434,10 +437,10 @@ def assert_lines_grouped_by_type( def get_polygons(self): """Returns all polygons on Axes ax as a sorted list of polygons where each polygon is a list of coord tuples - Returns - ------- - output: sorted list of polygons. Each polygon is a list tuples. Ecah tuples is a coordinate. - """ + Returns + ------- + output: sorted list of polygons. Each polygon is a list tuples. Ecah tuples is a coordinate. + """ output = [ [tuple(coords) for coords in path.vertices] for c in self.ax.collections @@ -448,17 +451,17 @@ def get_polygons(self): def _convert_multipolygons(self, series): """Helper function for assert_polygons - converts a pandas series of Polygon and MultiPolygon objects to a list of lines, - where each line is a list of coord tuples for the exterior + converts a pandas series of Polygon and MultiPolygon objects to a list of lines, + where each line is a list of coord tuples for the exterior - Parameters - ---------- - series: series where each entry is a Polygon or MultiPolygon + Parameters + ---------- + series: series where each entry is a Polygon or MultiPolygon - Returns - ------- - list of lines where each line is a list of coord tuples for the exterior polygon - """ + Returns + ------- + list of lines where each line is a list of coord tuples for the exterior polygon + """ output = [] for entry in series: if type(entry) == shapely.geometry.multipolygon.MultiPolygon: @@ -472,19 +475,19 @@ def assert_polygons( self, polygons_expected, dec=None, m="Incorrect Polygon Data" ): """Asserts the polygon data in Axes ax is equal to polygons_expected to decimal place dec with error message m - If polygons_expected is am empty list or None, assertion is passed. + If polygons_expected is am empty list or None, assertion is passed. Parameters ---------- polygons_expected : List or GeoDataFrame List of polygons expected to be founds on Axes ax or a GeoDataFrame containing the expected polygons. - dec : int (Optional) + dec : int (Optional) Int stating the desired decimal precision. If None, polygons must be exact. m : string (default = "Incorrect Polygon Data") String error message if assertion is not met. - """ + """ if len(polygons_expected) != 0: if isinstance(polygons_expected, list): if len(polygons_expected[0]) == 0: