Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions tests/unit_tests_plotter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import sys
import unittest
from unittest.mock import patch

import numpy as np
import pandas as pd

from unittest.mock import patch
from pandas.testing import assert_frame_equal
from pkg_resources import resource_filename

import tiny.rna.plotter as plotter
import tiny.rna.plotterlib as lib

class MyTestCase(unittest.TestCase):
class PlotterTests(unittest.TestCase):

@classmethod
def setUpClass(cls):
Expand All @@ -24,7 +26,10 @@ def get_label_width_pairs_from_annotations_mock(self, annotations):
def aqplt_mock(self):
return patch(
'tiny.rna.plotter.aqplt',
lib.plotterlib(resource_filename('tiny', 'templates/tinyrna-light.mplstyle'))
lib.plotterlib(
resource_filename('tiny', 'templates/tinyrna-light.mplstyle'),
**{'cache_scatter_ticks': False}
)
)

def get_empty_scatter_dge_dataframes(self):
Expand Down Expand Up @@ -178,6 +183,32 @@ def test_scatter_dge_class_empty_inclusive_filter(self):
plotter.scatter_by_dge_class(counts, dge, 'dummy_prefix', (0, 0), exclude=[])
scatter.assert_not_called()

"""Do scatter plots show the appropriate major ticks through a range of view limits?"""

@unittest.skip("Long-running test, execute manually if needed")
def test_scatter_major_ticks(self):
counts, dge = self.get_empty_scatter_dge_dataframes()
min_non_zero = 1 / sys.maxsize # avoid zero on log scale
fps = 3

counts.loc[('featA', 'featClassA'), 'ConditionA'] = min_non_zero
counts.loc[('featA', 'featClassA'), 'ConditionB'] = min_non_zero
dge.loc[('featA', 'featClassA'), 'ConditionA_vs_ConditionB'] = 0

for x in range(0, 121):
with self.aqplt_mock():
x /= fps # Range only produces integer values, but we want fractional powers of 2 in the demo

# Rolling view limit window
lo_bound = 2**(-6 + (x/2)) # Walk lower bound slowly forward from 2 decimal log2 minimum
hi_bound = 2**x + x # Walk upper bound forward much faster
vlim = np.array((lo_bound, hi_bound))

# title = f"Range: 2^{int(np.log2(view_lims[0]))} .. 2^{np.log2(view_lims[1]):.1f}"
# ^ must be set within scatter_* functions in plotter.py, not worth refactoring to support
plotter.scatter_by_dge(counts, dge, f'lim_{x:.2f}', vlim)



if __name__ == '__main__':
unittest.main()
32 changes: 14 additions & 18 deletions tiny/rna/plotterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

class plotterlib:

def __init__(self, user_style_sheet):
def __init__(self, user_style_sheet, **prefs):

self.debug = self.is_debug_mode()
if self.debug:
Expand All @@ -46,6 +46,7 @@ def __init__(self, user_style_sheet):
# Set global plot style once
plt.style.use(user_style_sheet)

self.prefs = prefs
self.subplot_cache = {}
self.dge_scatter_tick_cache = {}

Expand Down Expand Up @@ -383,8 +384,8 @@ def get_fixed_majorticklocs(view_lims: Tuple[float, float, float, float]) -> Tup
"""Produces a list of locations for major ticks for the given view limit"""

ax_min, ax_max = min(view_lims), max(view_lims)
floor, ceil, log2 = math.floor, math.ceil, np.log2
locs = [2 ** x for x in range(floor(log2(ax_min)), ceil(log2(ax_max)))]
ceil, log2 = math.ceil, np.log2
locs = [2 ** x for x in range(ceil(log2(ax_min)), ceil(log2(ax_max)))]
return locs, ax_min, ax_max

def set_scatter_ticks(self, ax: plt.Axes, minor_ticks=False):
Expand All @@ -402,15 +403,16 @@ def set_scatter_ticks(self, ax: plt.Axes, minor_ticks=False):

for axis in [ax.xaxis, ax.yaxis]:
# Only display every nth major tick label
ticks_displayed, last_idx = self.every_nth_label(axis, 3)
n = int(np.log2(len(major_locs)) - 1)
ticks_displayed, last_idx = self.every_nth_label(axis, n)

if minor_ticks:
axis.set_minor_locator(tix.LogLocator(
base=2.0,
numticks=self.get_min_LogLocator_numticks(axis),
subs=np.log2(np.linspace(2 ** 2, 2 ** 4, 10))[:-1]))

min_tick = 2 ** (np.log2(ax_min)+1)
min_tick = ax_min
max_tick = major_locs[last_idx]
self.set_tick_bounds(axis, min_tick=min_tick, max_tick=max_tick, minor=minor_ticks)
self.cache_ticks(axis, axis.__name__)
Expand All @@ -429,16 +431,6 @@ def every_nth_label(self, axis: mpl.axis.Axis, n: int) -> Tuple[List[mpl.axis.Ti
ticks_displayed.append(tick)
last_idx = i

# Hide tick labels in the lower left corner, regardless
major_ticks[0].label1.set_visible(False)

# If the last tick label on the x-axis will extend past the plot space,
# then hide it and its corresponding tick on the y-axis
if axis.__name__ == "xaxis" and axis.get_tick_space() == len(ticks_displayed):
major_ticks[last_idx].label1.set_visible(False)
yaxis = axis.axes.yaxis
yaxis.get_major_ticks()[last_idx].label1.set_visible(False)

return ticks_displayed, last_idx

def set_tick_bounds(self, axis: mpl.axis.Axis, min_tick: float, max_tick: float, minor=False):
Expand Down Expand Up @@ -466,9 +458,10 @@ def set_tick_bounds(self, axis: mpl.axis.Axis, min_tick: float, max_tick: float,
def cache_ticks(self, axis: mpl.axis.Axis, name: str):
"""Cache major and minor tick objects, which contain expensive data"""

for type in ["major", "minor"]:
self.dge_scatter_tick_cache[f"{name}_{type}_loc"] = getattr(axis, type).locator
self.dge_scatter_tick_cache[f"{name}_{type}_tix"] = getattr(axis, f"{type}Ticks")
if self.prefs.get('cache_scatter_ticks', True):
for type in ["major", "minor"]:
self.dge_scatter_tick_cache[f"{name}_{type}_loc"] = getattr(axis, type).locator
self.dge_scatter_tick_cache[f"{name}_{type}_tix"] = getattr(axis, f"{type}Ticks")

def restore_ticks(self, ax: plt.Axes, axis: str):
"""Restore tick objects from previous render"""
Expand Down Expand Up @@ -630,6 +623,9 @@ class CacheBase(ABC):
@abstractmethod
def get(self): pass

def __del__(self):
plt.close(self.fig)


class ClassChartCache(CacheBase):
def __init__(self):
Expand Down