diff --git a/tests/unit_tests_plotter.py b/tests/unit_tests_plotter.py index 09b03884..9939dde0 100644 --- a/tests/unit_tests_plotter.py +++ b/tests/unit_tests_plotter.py @@ -1,14 +1,16 @@ +import sys import unittest -from unittest.mock import patch - +import numpy as np import pandas as pd + +from unittest.mock import patch from pandas.testing import assert_frame_equal from pkg_resources import resource_filename import tiny.rna.plotter as plotter import tiny.rna.plotterlib as lib -class MyTestCase(unittest.TestCase): +class PlotterTests(unittest.TestCase): @classmethod def setUpClass(cls): @@ -24,7 +26,10 @@ def get_label_width_pairs_from_annotations_mock(self, annotations): def aqplt_mock(self): return patch( 'tiny.rna.plotter.aqplt', - lib.plotterlib(resource_filename('tiny', 'templates/tinyrna-light.mplstyle')) + lib.plotterlib( + resource_filename('tiny', 'templates/tinyrna-light.mplstyle'), + **{'cache_scatter_ticks': False} + ) ) def get_empty_scatter_dge_dataframes(self): @@ -178,6 +183,32 @@ def test_scatter_dge_class_empty_inclusive_filter(self): plotter.scatter_by_dge_class(counts, dge, 'dummy_prefix', (0, 0), exclude=[]) scatter.assert_not_called() + """Do scatter plots show the appropriate major ticks through a range of view limits?""" + + @unittest.skip("Long-running test, execute manually if needed") + def test_scatter_major_ticks(self): + counts, dge = self.get_empty_scatter_dge_dataframes() + min_non_zero = 1 / sys.maxsize # avoid zero on log scale + fps = 3 + + counts.loc[('featA', 'featClassA'), 'ConditionA'] = min_non_zero + counts.loc[('featA', 'featClassA'), 'ConditionB'] = min_non_zero + dge.loc[('featA', 'featClassA'), 'ConditionA_vs_ConditionB'] = 0 + + for x in range(0, 121): + with self.aqplt_mock(): + x /= fps # Range only produces integer values, but we want fractional powers of 2 in the demo + + # Rolling view limit window + lo_bound = 2**(-6 + (x/2)) # Walk lower bound slowly forward from 2 decimal log2 minimum + hi_bound = 2**x + x # Walk upper bound forward much faster + vlim = np.array((lo_bound, hi_bound)) + + # title = f"Range: 2^{int(np.log2(view_lims[0]))} .. 2^{np.log2(view_lims[1]):.1f}" + # ^ must be set within scatter_* functions in plotter.py, not worth refactoring to support + plotter.scatter_by_dge(counts, dge, f'lim_{x:.2f}', vlim) + + if __name__ == '__main__': unittest.main() diff --git a/tiny/rna/plotterlib.py b/tiny/rna/plotterlib.py index 5fc30a6e..e5afb74e 100644 --- a/tiny/rna/plotterlib.py +++ b/tiny/rna/plotterlib.py @@ -36,7 +36,7 @@ class plotterlib: - def __init__(self, user_style_sheet): + def __init__(self, user_style_sheet, **prefs): self.debug = self.is_debug_mode() if self.debug: @@ -46,6 +46,7 @@ def __init__(self, user_style_sheet): # Set global plot style once plt.style.use(user_style_sheet) + self.prefs = prefs self.subplot_cache = {} self.dge_scatter_tick_cache = {} @@ -383,8 +384,8 @@ def get_fixed_majorticklocs(view_lims: Tuple[float, float, float, float]) -> Tup """Produces a list of locations for major ticks for the given view limit""" ax_min, ax_max = min(view_lims), max(view_lims) - floor, ceil, log2 = math.floor, math.ceil, np.log2 - locs = [2 ** x for x in range(floor(log2(ax_min)), ceil(log2(ax_max)))] + ceil, log2 = math.ceil, np.log2 + locs = [2 ** x for x in range(ceil(log2(ax_min)), ceil(log2(ax_max)))] return locs, ax_min, ax_max def set_scatter_ticks(self, ax: plt.Axes, minor_ticks=False): @@ -402,7 +403,8 @@ def set_scatter_ticks(self, ax: plt.Axes, minor_ticks=False): for axis in [ax.xaxis, ax.yaxis]: # Only display every nth major tick label - ticks_displayed, last_idx = self.every_nth_label(axis, 3) + n = int(np.log2(len(major_locs)) - 1) + ticks_displayed, last_idx = self.every_nth_label(axis, n) if minor_ticks: axis.set_minor_locator(tix.LogLocator( @@ -410,7 +412,7 @@ def set_scatter_ticks(self, ax: plt.Axes, minor_ticks=False): numticks=self.get_min_LogLocator_numticks(axis), subs=np.log2(np.linspace(2 ** 2, 2 ** 4, 10))[:-1])) - min_tick = 2 ** (np.log2(ax_min)+1) + min_tick = ax_min max_tick = major_locs[last_idx] self.set_tick_bounds(axis, min_tick=min_tick, max_tick=max_tick, minor=minor_ticks) self.cache_ticks(axis, axis.__name__) @@ -429,16 +431,6 @@ def every_nth_label(self, axis: mpl.axis.Axis, n: int) -> Tuple[List[mpl.axis.Ti ticks_displayed.append(tick) last_idx = i - # Hide tick labels in the lower left corner, regardless - major_ticks[0].label1.set_visible(False) - - # If the last tick label on the x-axis will extend past the plot space, - # then hide it and its corresponding tick on the y-axis - if axis.__name__ == "xaxis" and axis.get_tick_space() == len(ticks_displayed): - major_ticks[last_idx].label1.set_visible(False) - yaxis = axis.axes.yaxis - yaxis.get_major_ticks()[last_idx].label1.set_visible(False) - return ticks_displayed, last_idx def set_tick_bounds(self, axis: mpl.axis.Axis, min_tick: float, max_tick: float, minor=False): @@ -466,9 +458,10 @@ def set_tick_bounds(self, axis: mpl.axis.Axis, min_tick: float, max_tick: float, def cache_ticks(self, axis: mpl.axis.Axis, name: str): """Cache major and minor tick objects, which contain expensive data""" - for type in ["major", "minor"]: - self.dge_scatter_tick_cache[f"{name}_{type}_loc"] = getattr(axis, type).locator - self.dge_scatter_tick_cache[f"{name}_{type}_tix"] = getattr(axis, f"{type}Ticks") + if self.prefs.get('cache_scatter_ticks', True): + for type in ["major", "minor"]: + self.dge_scatter_tick_cache[f"{name}_{type}_loc"] = getattr(axis, type).locator + self.dge_scatter_tick_cache[f"{name}_{type}_tix"] = getattr(axis, f"{type}Ticks") def restore_ticks(self, ax: plt.Axes, axis: str): """Restore tick objects from previous render""" @@ -630,6 +623,9 @@ class CacheBase(ABC): @abstractmethod def get(self): pass + def __del__(self): + plt.close(self.fig) + class ClassChartCache(CacheBase): def __init__(self):