diff --git a/CHANGELOG.md b/CHANGELOG.md index 415c27fc..eca75ef4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] * Created a vignette covering the testing of histograms (@ryla5068, #149) +* Created `get_plot_image()` function for the RasterTester object (@nkorinek, #192) * Allowed `assert_polygons()` to accept GeoDataFrames and added tests (@nkorinek, #175) ## [0.1.1] diff --git a/matplotcheck/raster.py b/matplotcheck/raster.py index 9e8aa5f0..4f62a8ce 100644 --- a/matplotcheck/raster.py +++ b/matplotcheck/raster.py @@ -164,6 +164,24 @@ def assert_legend_accuracy_classified_image( ### IMAGE TESTS/HELPER FUNCTIONS ### + def get_plot_image(self): + """Returns images stored on the Axes object as a list of numpy arrays. + + Returns + ------- + im_data: List + Numpy array of images stored on Axes object. + """ + im_data = [] + if self.ax.get_images(): + im_data = self.ax.get_images()[0].get_array() + assert list(im_data), "No Image Displayed" + + # If image array has 3 dims (e.g. rgb image), remove alpha channel + if len(im_data.shape) == 3: + im_data = im_data[:, :, :3] + return im_data + def assert_image( self, im_expected, im_classified=False, m="Incorrect Image Displayed" ): @@ -171,12 +189,14 @@ def assert_image( Parameters ---------- - im_expected: array containing the expected image data - im_classified: boolean, set to True image has been classified. - Since classified images values can be reversed or shifted and still - produce the same image, setting this to True will allow those - changes. - m: string error message if assertion is not met + im_expected: Numpy Array + Array containing the expected image data. + im_classified: boolean + Set to True image has been classified. Since classified images + values can be reversed or shifted and still produce the same image, + setting this to True will allow those changes. + m: string + String error message if assertion is not met. Returns ---------- diff --git a/matplotcheck/tests/test_raster.py b/matplotcheck/tests/test_raster.py index ab59e62e..dd623660 100644 --- a/matplotcheck/tests/test_raster.py +++ b/matplotcheck/tests/test_raster.py @@ -317,3 +317,9 @@ def test_raster_assert_image_fullscreen_blank(raster_plt_blank): with pytest.raises(AssertionError, match="No image found on axes"): raster_plt_blank.assert_image_full_screen() plt.close() + + +def test_get_plot_images(raster_plt_rgb): + """get_plot_image should get correct image from ax object""" + ax_im = raster_plt_rgb.get_plot_image() + raster_plt_rgb.assert_image(ax_im)