Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Unreleased
in them (@nkorinek, #182)
- added contributors file and updated README to remove that information
(@nkorinek, #121)
- Improved handling of datasets with different shapes in base.assert_xy() (@ryla5068, #233)
- Bug fix for handling object datatypes in base.assert_xy() (@ryla5068, #232)

0.1.2
-----
Expand Down
59 changes: 33 additions & 26 deletions matplotcheck/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
"""

import numpy as np
import matplotlib.dates as mdates
import matplotlib
from matplotlib.backend_bases import RendererBase
import math
from scipy import stats
import pandas as pd
import geopandas as gpd
import numbers
import geopandas as gpd


class InvalidPlotError(Exception):
Expand Down Expand Up @@ -766,7 +765,7 @@ def assert_no_legend_overlap(self, message="Legends overlap eachother"):

""" BASIC PLOT DATA FUNCTIONS """

def get_xy(self, points_only=False, xtime=False):
def get_xy(self, points_only=False):
"""Returns a pandas dataframe with columns "x" and "y" holding the x
and y coords on Axes `ax`

Expand All @@ -777,9 +776,6 @@ def get_xy(self, points_only=False, xtime=False):
points_only : boolean
Set ``True`` to check only points, set ``False`` to check all data
on plot.
xtime : boolean
Set equal to True if the x axis of the plot contains datetime
values

Returns
-------
Expand Down Expand Up @@ -820,9 +816,6 @@ def get_xy(self, points_only=False, xtime=False):
xy_data = xy_data[xy_data["x"] >= lims[0]]
xy_data = xy_data[xy_data["x"] <= lims[1]].reset_index(drop=True)

# change to datetime dtype if needed
if xtime:
xy_data["x"] = mdates.num2date(xy_data["x"])
return xy_data

def assert_xydata(
Expand All @@ -831,7 +824,6 @@ def assert_xydata(
xcol=None,
ycol=None,
points_only=False,
xtime=False,
xlabels=False,
tolerence=0,
message="Incorrect data values",
Expand All @@ -855,11 +847,6 @@ def assert_xydata(
points_only : boolean,
Set ``True`` to check only points, set ``False`` tp check all data
on plot.
xtime : boolean
Set ``True`` if the a-axis contains datetime values. Matplotlib
converts datetime objects to seconds? This parameter will ensure
the provided x col values are converted if they are datetime
elements.
xlabels : boolean
Set ``True`` if using x axis labels rather than x data. Instead of
comparing numbers in the x-column to expected, compares numbers or
Expand All @@ -869,7 +856,8 @@ def assert_xydata(
For example: Given a tolerance ``tolerence=0.1``, an expected value
``e``, and an actual value ``a``, this asserts
``abs(a - e) < (e * 0.1)``. (This uses `np.testing.assert_allclose`
with ``rtol=tolerence`` and ``atol=inf``.)
with ``rtol=tolerence`` and ``atol=inf``.) If using tolerance for
datetime data, units for tolerance will be in days.
message : string
The error message to be displayed if the xy-data does not match
`xy_expected`
Expand Down Expand Up @@ -905,7 +893,7 @@ def assert_xydata(
xy_expected, xcol=xcol, ycol=ycol, message=message
)
return
xy_data = self.get_xy(points_only=points_only, xtime=xtime)
xy_data = self.get_xy(points_only=points_only)

# Make sure the data are sorted the same
xy_data, xy_expected = (
Expand All @@ -914,8 +902,6 @@ def assert_xydata(
)

if tolerence > 0:
if xtime:
raise ValueError("tolerance must be 0 with datetime on x-axis")
np.testing.assert_allclose(
xy_data["x"],
xy_expected[xcol],
Expand All @@ -933,20 +919,41 @@ def assert_xydata(
"""We use `assert_array_max_ulp()` to compare the
two datasets because it is able to account for small errors in
floating point numbers, and it scales nicely between extremely
small or large numbers. We catch this error and throw our own so
that we can use our own message."""
small or large numbers. Because of the way that matplotlib stores
datetime data, this is essential for comparing high-precision
datetime data (i.e. millisecond or lower).

We catch this error and raise our own that is more relevant to
the assertion we are running."""
try:
np.testing.assert_array_max_ulp(
np.array(xy_data["x"]), np.array(xy_expected[xcol])
xy_data["x"].to_numpy(dtype=np.float64),
xy_expected[xcol].to_numpy(dtype=np.float64),
5,
)
except AssertionError:
# xy_data and xy_expected do not contain the same data
raise AssertionError(message)
except ValueError:
# xy_data and xy_expected do not have the same shape
raise ValueError(
"xy_data and xy_expected do not have the same shape"
)
try:
np.testing.assert_array_max_ulp(
np.array(xy_data["y"]), np.array(xy_expected[ycol])
xy_data["y"].to_numpy(dtype=np.float64),
xy_expected[ycol].to_numpy(dtype=np.float64),
5,
)

except AssertionError:
# xy_data and xy_expected do not contain the same data
raise AssertionError(message)
except ValueError:
# xy_data and xy_expected do not have the same shape
raise ValueError(
"xy_data and xy_expected do not have the same shape"
)

def assert_xlabel_ydata(
self, xy_expected, xcol, ycol, message="Incorrect Data"
Expand Down Expand Up @@ -1014,7 +1021,7 @@ def assert_xlabel_ydata(
if x_is_numeric:
try:
np.testing.assert_array_max_ulp(
np.array(xy_data["x"]), np.array(xy_expected[xcol])
np.array(xy_data["x"]), np.array(xy_expected[xcol]), 5
)
except AssertionError:
raise AssertionError(message)
Expand All @@ -1026,7 +1033,7 @@ def assert_xlabel_ydata(
# Testing y-data
try:
np.testing.assert_array_max_ulp(
np.array(xy_data["y"]), np.array(xy_expected[ycol])
np.array(xy_data["y"]), np.array(xy_expected[ycol]), 5
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious what the 5 is here that was added? @nkorinek when you look at this PR can you please see if it makes sense to you to have 5 there?

)
except AssertionError:
raise AssertionError(message)
Expand Down Expand Up @@ -1166,7 +1173,7 @@ def get_num_bins(self):
overlapping or stacked histograms in the same
`matplotlib.axis.Axis` object, then this returns the number of bins
with unique edges. """
x_data = self.get_xy(xtime=False)["x"]
x_data = self.get_xy()["x"]
unique_x_data = list(set(x_data))
num_bins = len(unique_x_data)

Expand Down
82 changes: 74 additions & 8 deletions matplotcheck/tests/test_timeseries_module.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,75 @@
'''
def test_assert_xydata_timeseries(pt_time_line_plt, pd_df_timeseries):
"""Commenting this out for now as this requires a time series data object
this is failing because the time data needs to be in seconds like how
mpl saves it. """
pt_time_line_plt.assert_xydata(
pd_df_timeseries, xcol='time', ycol='A', xtime=True
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
from matplotcheck.timeseries import TimeSeriesTester

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These import should be in pep8 order, matplotcheck being imported last


@pytest.fixture
def pd_df_timeseries():
"""Create a pandas dataframe for testing, with timeseries in one column"""
return pd.DataFrame(
{
"time": pd.date_range(start="1/1/2018", periods=100),
"A": np.random.randint(0, 100, size=100),
}
)
'''


@pytest.fixture
def pt_time_line_plt(pd_df_timeseries):
"""Create timeseries line plot for testing"""
fig, ax = plt.subplots()
pd_df_timeseries.plot("time", "A", kind="line", ax=ax)
axis = plt.gca()

return TimeSeriesTester(axis)


def test_assert_xydata_timeseries(pt_time_line_plt):
"""Tests that assert_xydata() correctly passes with matching timeseries
data."""
data = pt_time_line_plt.get_xy()
pt_time_line_plt.assert_xydata(data, xcol="x", ycol="y")


def test_assert_xydata_timeseries_fails(pt_time_line_plt):
"""Tests that assert_xydata() correctly fails without matching timeseries
data."""
data = pt_time_line_plt.get_xy()
data.loc[0, "x"] = 100
with pytest.raises(AssertionError, match="Incorrect data values"):
pt_time_line_plt.assert_xydata(data, xcol="x", ycol="y")


def test_assert_xydata_timeseries_truncation_error(
pt_time_line_plt, pd_df_timeseries
):
"""Tests that assert_xydata() handles floating-point truncation error
gracefully for timeseries data."""

pt1 = pt_time_line_plt

# Create second plottester object with slightly different time values
# The change in values here should be small enough that it gets truncated
# in matplotlib's conversion of datetime data
for i in range(len(pd_df_timeseries)):
pd_df_timeseries.loc[i, "time"] = pd_df_timeseries.loc[
i, "time"
] + pd.Timedelta(1)
fig, ax2 = plt.subplots()
pd_df_timeseries.plot("time", "A", kind="line", ax=ax2)
pt2 = TimeSeriesTester(ax2)

# Test that the two datasets assert as equal
data1 = pt1.get_xy()
pt2.assert_xydata(data1, xcol="x", ycol="y")


def test_assert_xydata_timeseries_roundoff_error(pt_time_line_plt):
"""Tests that assert_xydata() handles floating-point roundoff error
gracefully for timeseries data."""
data = pt_time_line_plt.get_xy()
data.loc[0, "x"] = data.loc[0, "x"] + 0.00000000001

pt_time_line_plt.assert_xydata(data, xcol="x", ycol="y")