Skip to content
79 changes: 69 additions & 10 deletions tiny/rna/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,9 @@ def scatter_by_dge_class(counts_avg_df, dges, output_prefix, view_lims, include=
"""

counts_avg_df, dges = filter_dge_classes(counts_avg_df, dges, include, exclude)
if counts_avg_df.empty or dges.empty: return
if counts_avg_df.empty or dges.empty:
print('ERROR: No classes passed filtering. Skipping scatter_by_dge_class.', file=sys.stderr)
return

uniq_classes = pd.unique(counts_avg_df.index.get_level_values(1))
class_colors = aqplt.assign_class_colors(uniq_classes)
Expand Down Expand Up @@ -442,6 +444,7 @@ def scatter_by_dge(counts_avg_df, dges, output_prefix, view_lims, pval=0.05):
"""

if counts_avg_df.empty or dges.empty:
print('ERROR: Received empty counts data. Skipping scatter_by_dge.', file=sys.stderr)
return

for pair in dges:
Expand Down Expand Up @@ -683,23 +686,79 @@ def main():
with mp.Pool(len(itinerary)) as pool:
results = []
for task, args, kwds in itinerary:
results.append(pool.apply_async(task, args, kwds, error_callback=err))
sentry = ExceptionManager(task)
results.append(pool.apply_async(task, args, kwds, error_callback=sentry))
for result in results:
result.wait()
else:
# Don't use multiprocessing if only one plot type requested
# or if in debug mode (matplotlib compatibility)
# or if in debug mode (for matplotlib compatibility)
for task, args, kwds in itinerary:
task(*args, **kwds)
try:
task(*args, **kwds)
except Exception as e:
ExceptionManager.add(task, e)

ExceptionManager.print_exceptions()


class ExceptionManager:
"""Handles exception formatting for more user-friendly logging with cwltool

In multiprocessing mode, you should create an instance for each task
(plot type) and exceptions will be stored at the class level for ALL tasks.
In sequential mode, you should use the add() method

Exception tracebacks are printed to stdout as soon as they happen, and since
the CWL CommandLineTool for tiny-plot captures stdout, this goes to the log
file rather than terminal. Once plotting is complete, the user-friendly
error is printed to stderr which isn't captured, so the user sees it.
The message includes instructions for `tiny replot` followed by an
exception summary (sans noisy traceback), organized by task."""

excs = defaultdict(list)

def __init__(self, task):
self.task = task

def err(e):
"""Allows us to print errors from a MP worker without discarding the other results"""
print(''.join(traceback.format_exception(type(e), e, e.__traceback__)))
print("\n\nPlotter encountered an error. Don't worry! You don't have to start over.\n"
"You can resume the pipeline at Plotter. To do so:\n\t"
def __call__(self, e):
"""The multiprocessing error_callback target"""
self.add(self.task, e, from_mp_worker=True)

@classmethod
def add(cls, task, e, from_mp_worker=False):
"""Prints task's traceback to stdout and stores exceptions for summary"""

if from_mp_worker:
print(e.__cause__.tb)
else:
ex_tuple = (type(e), e, e.__traceback__)
traceback.print_exception(*ex_tuple)

cls.excs[task].extend(traceback.format_exception_only(type(e), e))

@classmethod
def print_exceptions(cls):
"""Prints exception summary to stderr for all tasks"""

if not cls.excs: return
print('\n'.join(['', '=' * 75, '=' * 75]), file=sys.stderr)
print("\nPlotter encountered an error. Don't worry! You don't have to start over.\n"
"You can resume the pipeline at tiny-plot. To do so:\n\t"
"1. cd into your Run Directory\n\t"
'2. Run "tiny replot --config your_run_config.yml"\n\t'
' (that\'s the processed run config) ^^^\n\n', file=sys.stderr)
' (that\'s the processed run config) ^^^\n', file=sys.stderr)

ex_sum = sum(len(ex) for ex in cls.excs.values())
header = "The following {} reported:"
plural = "exceptions were" if ex_sum > 1 else "exception was"

exc_list = [header.format(plural)]
for task, task_exceptions in cls.excs.items():
exc_list.append('\t' + f"In function {task.__name__}():")
exc_list.append('\t\t'.join(['', *task_exceptions]))

print('\n'.join(exc_list), file=sys.stderr)


if __name__ == '__main__':
Expand Down
97 changes: 65 additions & 32 deletions tiny/rna/plotterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import re

# This has to be done before importing matplotlib.pyplot
# cwltool appears to unset all environment variables including those related to locale
# This leads to warnings from plt's FontConfig manager, but only for pipeline/cwl runs
curr_locale = locale.getlocale()
Expand All @@ -25,14 +26,15 @@
import matplotlib as mpl; mpl.use("PDF")
import matplotlib.pyplot as plt
import matplotlib.ticker as tix
import matplotlib.axis
from matplotlib.patches import Rectangle
from matplotlib.transforms import Bbox
from matplotlib.scale import LogTransform

from typing import Union, Tuple, List, Optional
from abc import ABC, abstractmethod

from tiny.rna.util import sorted_natural


class plotterlib:

Expand Down Expand Up @@ -233,65 +235,96 @@ def scatter_grouped(self, count_x: pd.DataFrame, count_y: pd.DataFrame, *groups,
gscat: A scatter plot containing groups highlighted with different colors
"""

# Subset counts not in *groups (for example, points with p-val above threshold)
# Subset counts not in *groups (e.g., p-val above threshold)
count_x_out = count_x.drop(itertools.chain(*groups))
count_y_out = count_y.drop(itertools.chain(*groups))

outgroup = count_x_out.any() and count_y_out.any()
group_it = iter(groups)

if outgroup:
gscat = self.scatter_simple(count_x_out, count_y_out, color='#B3B3B3', **kwargs)
else:
has_outgroup = all(co.replace(0, pd.NA).dropna().any()
for co in (count_x_out, count_y_out))

# Determine which groups we are able to plot on log scale
plottable_groups = self.get_nonzero_group_indexes(count_x, count_y, groups)
plot_labels = [labels[i] for i in plottable_groups]
plot_groups = [groups[i] for i in plottable_groups]
group_it = iter(plot_groups)

if has_outgroup:
x, y = count_x_out, count_y_out
gscat = self.scatter_simple(x, y, color='#B3B3B3', **kwargs)
elif plottable_groups:
group = next(group_it)
gscat = self.scatter_simple(count_x.loc[group], count_y.loc[group], **kwargs)
x, y = count_x.loc[group], count_y.loc[group]
gscat = self.scatter_simple(x, y, **kwargs)
else:
has_outgroup = None
x = y = pd.Series(dtype='float64')
gscat = self.scatter_simple(x, y, **kwargs)

# Add any remaining groups to the plot
zero_count_groups = []
for i, group in enumerate(group_it):
# Add remaining groups
for group in group_it:
x, y = count_x.loc[group], count_y.loc[group]
x_is_zeros = x.replace(0, pd.NA).dropna().empty
y_is_zeros = y.replace(0, pd.NA).dropna().empty
if x_is_zeros or y_is_zeros:
# This group and label won't be plotted
zero_count_groups.append(i)
continue
gscat.scatter(x, y, edgecolor='none', **kwargs)

labels = [l for i, l in enumerate(labels) if i not in zero_count_groups]
groups = [g for i, g in enumerate(groups) if i not in zero_count_groups]

self.sort_point_groups_and_label(gscat, groups, labels, colors, outgroup, pval)
self.sort_point_groups_and_label(gscat, plot_groups, plot_labels, colors, has_outgroup, pval)
self.set_square_scatter_view_lims(gscat, view_lims)
self.set_scatter_ticks(gscat)

return gscat

@staticmethod
def sort_point_groups_and_label(axes: plt.Axes, groups, labels, colors, outgroup, pval):
"""Sorts scatter groups so that less abundant groups are plotted on top to maximize visual representation.
After sorting, group colors and labels are assigned, and the legend is created."""
def get_nonzero_group_indexes(count_x, count_y, groups):
"""When scatter plotting groups for two conditions on a log scale, if one
of the conditions has all zero counts for the group, then none of the group's
points are actually plotted due to the singularity at 0. We want to skip
plotting these groups and omit them from the legend."""

non_zero_groups = []
for i, group in enumerate(groups):
x, y = count_x.loc[group], count_y.loc[group]
x_is_zeros = x.replace(0, pd.NA).dropna().empty
y_is_zeros = y.replace(0, pd.NA).dropna().empty
if not (x_is_zeros or y_is_zeros):
non_zero_groups.append(i)

return non_zero_groups

lorder = np.argsort([len(grp) for grp in groups if len(grp)])[::-1] # Label index of groups sorted largest -> smallest
offset = int(bool(outgroup and len(groups))) # For shifting indices to allow optional outgroup
@staticmethod
def sort_point_groups_and_label(axes: plt.Axes, groups, labels, colors, outgroup: Optional[bool], pval):
"""Sorts scatter groups so that those with fewer points are rendered on top of the stack.
After sorting, group colors and labels are assigned, and the legend is created. Labels
in the legend are sorted by natural order with the outgroup always listed last.
Args:
axes: The scatter plot Axes object
groups: A list of DataFrames that were able to be plotted
labels: A list of names, one for each group, for the corresponding index in `groups`
colors: A dictionary of group labels and their assigned colors
outgroup: True if an out group was plotted, None if empty plot (no groups or out groups)
"""

lorder = np.argsort([len(grp) for grp in groups if len(grp)])[::-1] # Index of groups by size
offset = int(bool(outgroup))
layers = axes.collections

if outgroup:
layers[0].set_label('p ≥ %g' % pval)
if labels is None:
labels = list(range(len(groups)))
if outgroup is None:
return

groupsize_sorted = [(labels[i], layers[i + offset]) for i in lorder]
for i, (label, layer) in enumerate(groupsize_sorted, start=1):
for z, (label, layer) in enumerate(groupsize_sorted, start=offset+1):
layer.set_label(re.sub(r'^_', ' _', label)) # To allow labels that start with _
layer.set_facecolor(colors[label])
layer.set_zorder(i) # Plot in order of group size
layer.set_zorder(z) # Plot in order of group size

# Ensure lines remain on top of points
for line in axes.lines:
line.set_zorder(len(groupsize_sorted) + 1)
line.set_zorder(len(layers) + 1)

axes.legend()
# Sort the legend with outgroup last while retaining layer order
handles = sorted_natural(layers[offset:], key=lambda x: x.get_label())
if outgroup: handles.append(layers[0])
axes.legend(handles=handles)

@staticmethod
def assign_class_colors(classes):
Expand Down
6 changes: 4 additions & 2 deletions tiny/rna/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ def __init__(self, rw_dict):
def __setitem__(self, *_):
raise RuntimeError("Attempted to modify read-only dictionary after construction.")

def sorted_natural(lines, reverse=False):

def sorted_natural(lines, key=None, reverse=False):
"""Sorts alphanumeric strings with entire numbers considered in the sorting order,
rather than the default behavior which is to sort by the individual ASCII values
of the given number. Returns a sorted copy of the list, just like sorted().
Expand All @@ -213,7 +214,8 @@ def sorted_natural(lines, reverse=False):
some time. Strange that there isn't something in the standard library for this."""

convert = lambda text: int(text) if text.isdigit() else text.lower()
alphanum_key = lambda key: [convert(c) for c in re.split(r'(\d+)', key)]
extract = (lambda data: key(data)) if key is not None else lambda x: x
alphanum_key = lambda elem: [convert(c) for c in re.split(r'(\d+)', extract(elem))]
return sorted(lines, key=alphanum_key, reverse=reverse)


Expand Down