diff --git a/tiny/rna/plotter.py b/tiny/rna/plotter.py index 06283976..7ab2ca14 100644 --- a/tiny/rna/plotter.py +++ b/tiny/rna/plotter.py @@ -291,7 +291,7 @@ def scatter_replicates(count_df: pd.DataFrame, samples: dict, output_prefix: str for samp, reps in samples.items(): for pair in itertools.combinations(reps, 2): rscat = aqplt.scatter_simple(count_df.loc[:,pair[0]], count_df.loc[:,pair[1]], - color='#B3B3B3', alpha=0.5, log_norm=True, rasterized=RASTER) + color='#B3B3B3', alpha=0.5, rasterized=RASTER) aqplt.set_square_scatter_view_lims(rscat, view_lims) aqplt.set_scatter_ticks(rscat) rscat.set_title(samp) @@ -347,10 +347,12 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un output_prefix: A string to use as a prefix for saving files classes: An optional dataframe of class(es) per feature. If provided, points are grouped by class show_unknown: If true, class "unknown" will be included if highlighting by classes + pval: The pvalue threshold for determining the outgroup """ if classes is not None: - uniq_classes = sorted(list(pd.unique(classes))) + uniq_classes = sorted(list(pd.unique(classes)), key=str.lower) + class_colors = aqplt.assign_class_colors(uniq_classes) aqplt.set_dge_class_legend_style() if not show_unknown and 'unknown' in uniq_classes: @@ -358,18 +360,17 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un for pair in dges: p1, p2 = pair.split("_vs_") + # Get list of differentially expressed features for this comparison pair dge_list = list(dges.index[dges[pair] < pval]) + # Create series of feature -> class relationships class_dges = classes.loc[dge_list] - + # Gather lists of features by class (listed in order corresponding to unique_classes) grp_args = [class_dges.index[class_dges == cls].tolist() for cls in uniq_classes] - layer_order = np.argsort([len(grp) for grp in grp_args])[::-1] - sorted_grps = [grp_args[i] for i in layer_order] - sorted_clss = [uniq_classes[i] for i in 
layer_order] + labels = uniq_classes + sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], *grp_args, colors=class_colors, + pval=pval, view_lims=view_lims, labels=labels, rasterized=RASTER) - labels = ['p ≥ %g' % pval] + sorted_clss - sscat = aqplt.scatter_grouped(count_df.loc[:,p1], count_df.loc[:,p2], view_lims, *sorted_grps, - log_norm=True, labels=labels, rasterized=RASTER) sscat.set_title('%s vs %s' % (p1, p2)) sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1) sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2) @@ -382,9 +383,11 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un grp_args = list(dges.index[dges[pair] < pval]) p1, p2 = pair.split("_vs_") - labels = ['p ≥ %g' % pval, 'p < %g' % pval] - sscat = aqplt.scatter_grouped(count_df.loc[:,p1], count_df.loc[:,p2], view_lims, grp_args, - log_norm=True, labels=labels, alpha=0.5, rasterized=RASTER) + labels = ['p < %g' % pval] + colors = aqplt.assign_class_colors(labels) + sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], grp_args, colors=colors, alpha=0.5, + pval=pval, view_lims=view_lims, labels=labels, rasterized=RASTER) + sscat.set_title('%s vs %s' % (p1, p2)) sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1) sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2) diff --git a/tiny/rna/plotterlib.py b/tiny/rna/plotterlib.py index 36fb7710..3444b3c7 100644 --- a/tiny/rna/plotterlib.py +++ b/tiny/rna/plotterlib.py @@ -160,13 +160,13 @@ def barh_proportion(self, prop_ds: pd.Series, max_prop=1.0, scale=2, **kwargs) - return cbar - def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_norm=False, **kwargs) -> plt.Axes: + def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_scale=True, **kwargs) -> plt.Axes: """Creates a simple scatter plot of counts. 
def scatter_grouped(self, count_x: pd.DataFrame, count_y: pd.DataFrame, *groups, colors: dict, pval=0.05,
                    view_lims: Tuple[float, float] = None, labels=None, **kwargs):
    """Creates a scatter plot with different groups highlighted.

    Features that are not members of any group form the "outgroup", which is
    plotted first in light grey. Each group is then added as its own layer, and
    sort_point_groups_and_label() assigns colors, labels, z-order and the legend.

    Args:
        count_x: A pandas dataframe/series of counts per feature (X axis)
        count_y: A pandas dataframe/series of counts per feature (Y axis)
        groups: An iterable of lists each representing a group
        labels: An iterable of labels corresponding to each group
        colors: A dictionary of label -> color for each group
        view_lims: Optional plot view limits as tuple(min, max)
        pval: The p-value to use for the outgroup label

    Keyword Args:
        log_scale: Data is plotted on log scale if true (default: true)
        kwargs: Additional arguments to pass to pyplot.Axes.scatter()

    Returns:
        gscat: A scatter plot containing groups highlighted with different colors
    """

    # log_scale is only understood by scatter_simple(). Remove it from kwargs so that
    # it is not forwarded to Axes.scatter(), which rejects unknown keywords.
    log_scale = kwargs.pop('log_scale', True)

    # Subset counts not in *groups (for example, points with p-val above threshold)
    grouped_features = list(itertools.chain.from_iterable(groups))
    count_x_out = count_x.drop(grouped_features)
    count_y_out = count_y.drop(grouped_features)

    # The outgroup exists if any features remain, regardless of their count values
    # (.any() would report False for an outgroup whose counts are all zero)
    has_outgroup = not (count_x_out.empty or count_y_out.empty)
    group_it = iter(groups)

    if has_outgroup:
        # Plot the outgroup in light grey; it becomes the first collection on the axes
        gscat = self.scatter_simple(count_x_out, count_y_out, log_scale=log_scale,
                                    color='#B3B3B3', **kwargs)
    else:
        # There is no outgroup, plot the first group with scatter_simple() to set scale and lines
        group = next(group_it)
        gscat = self.scatter_simple(count_x.loc[group], count_y.loc[group], log_scale=log_scale, **kwargs)

    # Add any remaining groups to the plot (colors are assigned afterwards, by label)
    for group in group_it:
        gscat.scatter(count_x.loc[group], count_y.loc[group], edgecolor='none', **kwargs)

    self.sort_point_groups_and_label(gscat, groups, labels, colors, has_outgroup, pval)
    self.set_square_scatter_view_lims(gscat, view_lims)
    self.set_scatter_ticks(gscat)

    return gscat

@staticmethod
def sort_point_groups_and_label(axes: plt.Axes, groups, labels, colors, outgroup, pval):
    """Sorts scatter groups so that less abundant groups are plotted on top to maximize visual representation.
    After sorting, group colors and labels are assigned, and the legend is created.

    Relies on axes.collections holding the layers in plotting order: the outgroup
    first (only when `outgroup` is true), followed by one collection per group in
    the same order as `groups`.

    Args:
        axes: The scatter plot's axes
        groups: The groups that were plotted, in plotting order
        labels: A label for each group; defaults to the group's index if None
        colors: A dictionary of label -> color
        outgroup: True if the first collection on the axes is the outgroup
        pval: The p-value to use for the outgroup label
    """

    if labels is None:
        labels = list(range(len(groups)))

    # Collection index of a group = its index in `groups`, shifted by one if the outgroup is present
    offset = 1 if outgroup else 0
    if outgroup:
        background = axes.collections[0]
        background.set_label('p ≥ %g' % pval)
        background.set_zorder(0)  # Keep the outgroup beneath every highlighted group

    # Group indices sorted largest -> smallest
    by_size = np.argsort([len(grp) for grp in groups])[::-1]

    # Z-order follows size rank (not group index): largest group lowest, smallest group on top
    for rank, idx in enumerate(by_size, start=1):
        points = axes.collections[idx + offset]
        label = labels[idx]
        points.set_facecolor(colors[label])
        points.set_label(label)
        points.set_zorder(rank)

    # Ensure lines remain on top of points
    top = len(groups) + 1
    for line in axes.lines:
        line.set_zorder(top)

    axes.legend()

@staticmethod
def assign_class_colors(classes):
    """Assigns a color to each class for consistency across samples

    Colors come from the active stylesheet's property cycle. If there are more
    classes than stylesheet colors, the "tab20" colormap's color list is used
    instead. If there are still more classes than colors, colors repeat.

    Args:
        classes: A sized iterable of class names (or other hashable labels)

    Returns: A dictionary of class -> color
    """

    palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    if len(classes) > len(palette):
        # Colormap objects are not iterable; the list of colors is held in .colors
        palette = plt.get_cmap("tab20").colors

    # cycle() avoids StopIteration when classes outnumber the available colors
    return dict(zip(classes, itertools.cycle(palette)))

# Visible head of the next method, reproduced unchanged (its body continues beyond this chunk)
def set_dge_class_legend_style(self):
    """Widens the "scatter" figure and moves plot to the left to accommodate legend"""

    expand_width_inches = 3
    fig, scatter = self.reuse_subplot("scatter")