diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 88057a6d..8bbec950 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,8 @@ Unreleased (@nkorinek, #121) - Changed tolerance functionality from relative tolerance to absolute tolerance. (@ryla5068, #234) +- Improved handling of datasets with different shapes in base.assert_xydata() (@ryla5068, #233) +- Bug fix for handling object datatypes in base.assert_xydata() (@ryla5068, #232) 0.1.2 ----- diff --git a/matplotcheck/base.py b/matplotcheck/base.py index 3f896c47..a7c9213e 100644 --- a/matplotcheck/base.py +++ b/matplotcheck/base.py @@ -8,14 +8,13 @@ """ import numpy as np -import matplotlib.dates as mdates import matplotlib from matplotlib.backend_bases import RendererBase import math from scipy import stats import pandas as pd -import geopandas as gpd import numbers +import geopandas as gpd class InvalidPlotError(Exception): @@ -766,7 +765,7 @@ def assert_no_legend_overlap(self, message="Legends overlap eachother"): """ BASIC PLOT DATA FUNCTIONS """ - def get_xy(self, points_only=False, xtime=False): + def get_xy(self, points_only=False): """Returns a pandas dataframe with columns "x" and "y" holding the x and y coords on Axes `ax` @@ -777,9 +776,6 @@ def get_xy(self, points_only=False, xtime=False): points_only : boolean Set ``True`` to check only points, set ``False`` to check all data on plot. 
- xtime : boolean - Set equal to True if the x axis of the plot contains datetime - values Returns ------- @@ -820,9 +816,6 @@ def get_xy(self, points_only=False, xtime=False): xy_data = xy_data[xy_data["x"] >= lims[0]] xy_data = xy_data[xy_data["x"] <= lims[1]].reset_index(drop=True) - # change to datetime dtype if needed - if xtime: - xy_data["x"] = mdates.num2date(xy_data["x"]) return xy_data def assert_xydata( @@ -831,7 +824,6 @@ def assert_xydata( xcol=None, ycol=None, points_only=False, - xtime=False, xlabels=False, tolerance=0, message="Incorrect data values", @@ -856,11 +848,6 @@ def assert_xydata( points_only : boolean, Set ``True`` to check only points, set ``False`` tp check all data on plot. - xtime : boolean - Set ``True`` if the a-axis contains datetime values. Matplotlib - converts datetime objects to seconds? This parameter will ensure - the provided x col values are converted if they are datetime - elements. xlabels : boolean Set ``True`` if using x axis labels rather than x data. Instead of comparing numbers in the x-column to expected, compares numbers or @@ -906,7 +893,7 @@ def assert_xydata( xy_expected, xcol=xcol, ycol=ycol, message=message ) return - xy_data = self.get_xy(points_only=points_only, xtime=xtime) + xy_data = self.get_xy(points_only=points_only) # Make sure the data are sorted the same xy_data, xy_expected = ( @@ -915,8 +902,6 @@ def assert_xydata( ) if tolerance > 0: - if xtime: - raise ValueError("tolerance must be 0 with datetime on x-axis") np.testing.assert_allclose( xy_data["x"], xy_expected[xcol], @@ -934,20 +919,41 @@ def assert_xydata( """We use `assert_array_max_ulp()` to compare the two datasets because it is able to account for small errors in floating point numbers, and it scales nicely between extremely - small or large numbers. We catch this error and throw our own so - that we can use our own message.""" + small or large numbers. 
Because of the way that matplotlib stores + datetime data, this is essential for comparing high-precision + datetime data (i.e. millisecond or lower). + + We catch this error and raise our own that is more relevant to + the assertion being run.""" try: np.testing.assert_array_max_ulp( - np.array(xy_data["x"]), np.array(xy_expected[xcol]) + xy_data["x"].to_numpy(dtype=np.float64), + xy_expected[xcol].to_numpy(dtype=np.float64), + 5, ) except AssertionError: + # xy_data and xy_expected do not contain the same data raise AssertionError(message) + except ValueError: + # xy_data and xy_expected do not have the same shape + raise ValueError( + "xy_data and xy_expected do not have the same shape" + ) try: np.testing.assert_array_max_ulp( - np.array(xy_data["y"]), np.array(xy_expected[ycol]) + xy_data["y"].to_numpy(dtype=np.float64), + xy_expected[ycol].to_numpy(dtype=np.float64), + 5, ) + except AssertionError: + # xy_data and xy_expected do not contain the same data raise AssertionError(message) + except ValueError: + # xy_data and xy_expected do not have the same shape + raise ValueError( + "xy_data and xy_expected do not have the same shape" + ) def assert_xlabel_ydata( self, xy_expected, xcol, ycol, message="Incorrect Data" @@ -1015,7 +1021,7 @@ def assert_xlabel_ydata( if x_is_numeric: try: np.testing.assert_array_max_ulp( - np.array(xy_data["x"]), np.array(xy_expected[xcol]) + np.array(xy_data["x"]), np.array(xy_expected[xcol]), ) except AssertionError: raise AssertionError(message) @@ -1167,7 +1173,7 @@ def get_num_bins(self): overlapping or stacked histograms in the same `matplotlib.axis.Axis` object, then this returns the number of bins with unique edges. 
""" - x_data = self.get_xy(xtime=False)["x"] + x_data = self.get_xy()["x"] unique_x_data = list(set(x_data)) num_bins = len(unique_x_data) diff --git a/matplotcheck/tests/test_timeseries_module.py b/matplotcheck/tests/test_timeseries_module.py index b1be1448..e8179bf5 100644 --- a/matplotcheck/tests/test_timeseries_module.py +++ b/matplotcheck/tests/test_timeseries_module.py @@ -1,9 +1,75 @@ -''' -def test_assert_xydata_timeseries(pt_time_line_plt, pd_df_timeseries): - """Commenting this out for now as this requires a time series data object - this is failing because the time data needs to be in seconds like how - mpl saves it. """ - pt_time_line_plt.assert_xydata( - pd_df_timeseries, xcol='time', ycol='A', xtime=True +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest +from matplotcheck.timeseries import TimeSeriesTester + + +@pytest.fixture +def pd_df_timeseries(): + """Create a pandas dataframe for testing, with timeseries in one column""" + return pd.DataFrame( + { + "time": pd.date_range(start="1/1/2018", periods=100), + "A": np.random.randint(0, 100, size=100), + } ) -''' + + +@pytest.fixture +def pt_time_line_plt(pd_df_timeseries): + """Create timeseries line plot for testing""" + fig, ax = plt.subplots() + pd_df_timeseries.plot("time", "A", kind="line", ax=ax) + axis = plt.gca() + + return TimeSeriesTester(axis) + + +def test_assert_xydata_timeseries(pt_time_line_plt): + """Tests that assert_xydata() correctly passes with matching timeseries + data.""" + data = pt_time_line_plt.get_xy() + pt_time_line_plt.assert_xydata(data, xcol="x", ycol="y") + + +def test_assert_xydata_timeseries_fails(pt_time_line_plt): + """Tests that assert_xydata() correctly fails without matching timeseries + data.""" + data = pt_time_line_plt.get_xy() + data.loc[0, "x"] = 100 + with pytest.raises(AssertionError, match="Incorrect data values"): + pt_time_line_plt.assert_xydata(data, xcol="x", ycol="y") + + +def 
test_assert_xydata_timeseries_truncation_error( + pt_time_line_plt, pd_df_timeseries +): + """Tests that assert_xydata() handles floating-point truncation error + gracefully for timeseries data.""" + + pt1 = pt_time_line_plt + + # Create second plottester object with slightly different time values + # The change in values here should be small enough that it gets truncated + # in matplotlib's conversion of datetime data + for i in range(len(pd_df_timeseries)): + pd_df_timeseries.loc[i, "time"] = pd_df_timeseries.loc[ + i, "time" + ] + pd.Timedelta(1) + fig, ax2 = plt.subplots() + pd_df_timeseries.plot("time", "A", kind="line", ax=ax2) + pt2 = TimeSeriesTester(ax2) + + # Test that the two datasets assert as equal + data1 = pt1.get_xy() + pt2.assert_xydata(data1, xcol="x", ycol="y") + + +def test_assert_xydata_timeseries_roundoff_error(pt_time_line_plt): + """Tests that assert_xydata() handles floating-point roundoff error + gracefully for timeseries data.""" + data = pt_time_line_plt.get_xy() + data.loc[0, "x"] = data.loc[0, "x"] + 0.00000000001 + + pt_time_line_plt.assert_xydata(data, xcol="x", ycol="y")