diff --git a/tiny/rna/plotter.py b/tiny/rna/plotter.py index e0bd13bd..6ff4ad35 100644 --- a/tiny/rna/plotter.py +++ b/tiny/rna/plotter.py @@ -402,7 +402,9 @@ def scatter_by_dge_class(counts_avg_df, dges, output_prefix, view_lims, include= """ counts_avg_df, dges = filter_dge_classes(counts_avg_df, dges, include, exclude) - if counts_avg_df.empty or dges.empty: return + if counts_avg_df.empty or dges.empty: + print('ERROR: No classes passed filtering. Skipping scatter_by_dge_class.', file=sys.stderr) + return uniq_classes = pd.unique(counts_avg_df.index.get_level_values(1)) class_colors = aqplt.assign_class_colors(uniq_classes) @@ -442,6 +444,7 @@ def scatter_by_dge(counts_avg_df, dges, output_prefix, view_lims, pval=0.05): """ if counts_avg_df.empty or dges.empty: + print('ERROR: Received empty counts data. Skipping scatter_by_dge.', file=sys.stderr) return for pair in dges: @@ -683,23 +686,79 @@ def main(): with mp.Pool(len(itinerary)) as pool: results = [] for task, args, kwds in itinerary: - results.append(pool.apply_async(task, args, kwds, error_callback=err)) + sentry = ExceptionManager(task) + results.append(pool.apply_async(task, args, kwds, error_callback=sentry)) for result in results: result.wait() else: # Don't use multiprocessing if only one plot type requested - # or if in debug mode (matplotlib compatibility) + # or if in debug mode (for matplotlib compatibility) for task, args, kwds in itinerary: - task(*args, **kwds) + try: + task(*args, **kwds) + except Exception as e: + ExceptionManager.add(task, e) + + ExceptionManager.print_exceptions() + + +class ExceptionManager: + """Handles exception formatting for more user-friendly logging with cwltool + + In multiprocessing mode, you should create an instance for each task + (plot type) and exceptions will be stored at the class level for ALL tasks. 
+ In sequential mode, you should use the add() method + + Exception tracebacks are printed to stdout as soon as they happen, and since + the CWL CommandLineTool for tiny-plot captures stdout, this goes to the log + file rather than terminal. Once plotting is complete, the user-friendly + error is printed to stderr which isn't captured, so the user sees it. + The message includes instructions for `tiny replot` followed by an + exception summary (sans noisy traceback), organized by task.""" + + excs = defaultdict(list) + + def __init__(self, task): + self.task = task -def err(e): - """Allows us to print errors from a MP worker without discarding the other results""" - print(''.join(traceback.format_exception(type(e), e, e.__traceback__))) - print("\n\nPlotter encountered an error. Don't worry! You don't have to start over.\n" - "You can resume the pipeline at Plotter. To do so:\n\t" + def __call__(self, e): + """The multiprocessing error_callback target""" + self.add(self.task, e, from_mp_worker=True) + + @classmethod + def add(cls, task, e, from_mp_worker=False): + """Prints task's traceback to stdout and stores exceptions for summary""" + + remote_tb = getattr(e.__cause__, 'tb', None) if from_mp_worker else None + if remote_tb is not None: + print(remote_tb) + else: + traceback.print_exception(type(e), e, e.__traceback__, file=sys.stdout) + + cls.excs[task].extend(traceback.format_exception_only(type(e), e)) + + @classmethod + def print_exceptions(cls): + """Prints exception summary to stderr for all tasks""" + + if not cls.excs: return + print('\n'.join(['', '=' * 75, '=' * 75]), file=sys.stderr) + print("\nPlotter encountered an error. Don't worry! You don't have to start over.\n" + "You can resume the pipeline at tiny-plot. To do so:\n\t" "1. cd into your Run Directory\n\t" '2. 
Run "tiny replot --config your_run_config.yml"\n\t' - ' (that\'s the processed run config) ^^^\n\n', file=sys.stderr) + ' (that\'s the processed run config) ^^^\n', file=sys.stderr) + + ex_sum = sum(len(ex) for ex in cls.excs.values()) + header = "The following {} reported:" + plural = "exceptions were" if ex_sum > 1 else "exception was" + + exc_list = [header.format(plural)] + for task, task_exceptions in cls.excs.items(): + exc_list.append('\t' + f"In function {task.__name__}():") + exc_list.append('\t\t'.join(['', *task_exceptions])) + + print('\n'.join(exc_list), file=sys.stderr) if __name__ == '__main__': diff --git a/tiny/rna/plotterlib.py b/tiny/rna/plotterlib.py index ebebcd81..fe0faba0 100644 --- a/tiny/rna/plotterlib.py +++ b/tiny/rna/plotterlib.py @@ -15,6 +15,7 @@ import os import re +# This has to be done before importing matplotlib.pyplot # cwltool appears to unset all environment variables including those related to locale # This leads to warnings from plt's FontConfig manager, but only for pipeline/cwl runs curr_locale = locale.getlocale() @@ -25,7 +26,6 @@ import matplotlib as mpl; mpl.use("PDF") import matplotlib.pyplot as plt import matplotlib.ticker as tix -import matplotlib.axis from matplotlib.patches import Rectangle from matplotlib.transforms import Bbox from matplotlib.scale import LogTransform @@ -33,6 +33,8 @@ from typing import Union, Tuple, List, Optional from abc import ABC, abstractmethod +from tiny.rna.util import sorted_natural + class plotterlib: @@ -233,65 +235,96 @@ def scatter_grouped(self, count_x: pd.DataFrame, count_y: pd.DataFrame, *groups, gscat: A scatter plot containing groups highlighted with different colors """ - # Subset counts not in *groups (for example, points with p-val above threshold) + # Subset counts not in *groups (e.g., p-val above threshold) count_x_out = count_x.drop(itertools.chain(*groups)) count_y_out = count_y.drop(itertools.chain(*groups)) - - outgroup = count_x_out.any() and count_y_out.any() - 
group_it = iter(groups) - - if outgroup: - gscat = self.scatter_simple(count_x_out, count_y_out, color='#B3B3B3', **kwargs) - else: + has_outgroup = all(co.replace(0, pd.NA).dropna().any() + for co in (count_x_out, count_y_out)) + + # Determine which groups we are able to plot on log scale + plottable_groups = self.get_nonzero_group_indexes(count_x, count_y, groups) + plot_labels = [labels[i] for i in plottable_groups] + plot_groups = [groups[i] for i in plottable_groups] + group_it = iter(plot_groups) + + if has_outgroup: + x, y = count_x_out, count_y_out + gscat = self.scatter_simple(x, y, color='#B3B3B3', **kwargs) + elif plottable_groups: group = next(group_it) - gscat = self.scatter_simple(count_x.loc[group], count_y.loc[group], **kwargs) + x, y = count_x.loc[group], count_y.loc[group] + gscat = self.scatter_simple(x, y, **kwargs) + else: + has_outgroup = None + x = y = pd.Series(dtype='float64') + gscat = self.scatter_simple(x, y, **kwargs) - # Add any remaining groups to the plot - zero_count_groups = [] - for i, group in enumerate(group_it): + # Add remaining groups + for group in group_it: x, y = count_x.loc[group], count_y.loc[group] - x_is_zeros = x.replace(0, pd.NA).dropna().empty - y_is_zeros = y.replace(0, pd.NA).dropna().empty - if x_is_zeros or y_is_zeros: - # This group and label won't be plotted - zero_count_groups.append(i) - continue gscat.scatter(x, y, edgecolor='none', **kwargs) - labels = [l for i, l in enumerate(labels) if i not in zero_count_groups] - groups = [g for i, g in enumerate(groups) if i not in zero_count_groups] - - self.sort_point_groups_and_label(gscat, groups, labels, colors, outgroup, pval) + self.sort_point_groups_and_label(gscat, plot_groups, plot_labels, colors, has_outgroup, pval) self.set_square_scatter_view_lims(gscat, view_lims) self.set_scatter_ticks(gscat) return gscat @staticmethod - def sort_point_groups_and_label(axes: plt.Axes, groups, labels, colors, outgroup, pval): - """Sorts scatter groups so that less 
abundant groups are plotted on top to maximize visual representation. - After sorting, group colors and labels are assigned, and the legend is created.""" + def get_nonzero_group_indexes(count_x, count_y, groups): + """When scatter plotting groups for two conditions on a log scale, if one + of the conditions has all zero counts for the group, then none of the group's + points are actually plotted due to the singularity at 0. We want to skip + plotting these groups and omit them from the legend.""" + + non_zero_groups = [] + for i, group in enumerate(groups): + x, y = count_x.loc[group], count_y.loc[group] + x_is_zeros = x.replace(0, pd.NA).dropna().empty + y_is_zeros = y.replace(0, pd.NA).dropna().empty + if not (x_is_zeros or y_is_zeros): + non_zero_groups.append(i) + + return non_zero_groups - lorder = np.argsort([len(grp) for grp in groups if len(grp)])[::-1] # Label index of groups sorted largest -> smallest - offset = int(bool(outgroup and len(groups))) # For shifting indices to allow optional outgroup + @staticmethod + def sort_point_groups_and_label(axes: plt.Axes, groups, labels, colors, outgroup: Optional[bool], pval): + """Sorts scatter groups so that those with fewer points are rendered on top of the stack. + After sorting, group colors and labels are assigned, and the legend is created. Labels + in the legend are sorted by natural order with the outgroup always listed last. 
+ Args: + axes: The scatter plot Axes object + groups: A list of DataFrames that were able to be plotted + labels: A list of names, one for each group, for the corresponding index in `groups` + colors: A dictionary of group labels and their assigned colors + outgroup: True if an out group was plotted, None if empty plot (no groups or out groups) + """ + + lorder = np.argsort([len(grp) for grp in groups if len(grp)])[::-1] # Index of groups by size + offset = int(bool(outgroup)) layers = axes.collections if outgroup: layers[0].set_label('p ≥ %g' % pval) if labels is None: labels = list(range(len(groups))) + if outgroup is None: + return groupsize_sorted = [(labels[i], layers[i + offset]) for i in lorder] - for i, (label, layer) in enumerate(groupsize_sorted, start=1): + for z, (label, layer) in enumerate(groupsize_sorted, start=offset+1): layer.set_label(re.sub(r'^_', ' _', label)) # To allow labels that start with _ layer.set_facecolor(colors[label]) - layer.set_zorder(i) # Plot in order of group size + layer.set_zorder(z) # Plot in order of group size # Ensure lines remain on top of points for line in axes.lines: - line.set_zorder(len(groupsize_sorted) + 1) + line.set_zorder(len(layers) + 1) - axes.legend() + # Sort the legend with outgroup last while retaining layer order + handles = sorted_natural(layers[offset:], key=lambda x: x.get_label()) + if outgroup: handles.append(layers[0]) + axes.legend(handles=handles) @staticmethod def assign_class_colors(classes): diff --git a/tiny/rna/util.py b/tiny/rna/util.py index cd0e4b6b..3e1526b6 100644 --- a/tiny/rna/util.py +++ b/tiny/rna/util.py @@ -204,7 +204,8 @@ def __init__(self, rw_dict): def __setitem__(self, *_): raise RuntimeError("Attempted to modify read-only dictionary after construction.") -def sorted_natural(lines, reverse=False): + +def sorted_natural(lines, key=None, reverse=False): """Sorts alphanumeric strings with entire numbers considered in the sorting order, rather than the default behavior which is to 
sort by the individual ASCII values of the given number. Returns a sorted copy of the list, just like sorted(). @@ -213,7 +214,8 @@ def sorted_natural(lines, reverse=False): some time. Strange that there isn't something in the standard library for this.""" convert = lambda text: int(text) if text.isdigit() else text.lower() - alphanum_key = lambda key: [convert(c) for c in re.split(r'(\d+)', key)] + extract = (lambda data: key(data)) if key is not None else lambda x: x + alphanum_key = lambda elem: [convert(c) for c in re.split(r'(\d+)', extract(elem))] return sorted(lines, key=alphanum_key, reverse=reverse)