From bef4dc4814ea78b8ae7827cc7c211a3cf9f9fd8d Mon Sep 17 00:00:00 2001 From: Alex Tate <0xalextate@gmail.com> Date: Tue, 2 Aug 2022 19:23:32 -0700 Subject: [PATCH] Refactor of plotterlib.scatter_grouped() is complete. This fixes class color assignment so that colors are consistent across samples. This method and its signature have been cleaned up a bit. This has helped simplify plotter.scatter_dges() as well tiny-plot will also fail over to a 20 item colormap if the number of groups exceeds the number of colors defined in the user's stylesheet. Categorical color maps over 20 items aren't offered by MPL, but if needed the map will repeat as necessary --- tiny/rna/plotter.py | 27 ++++++------ tiny/rna/plotterlib.py | 94 ++++++++++++++++++++++++++++++------------ 2 files changed, 83 insertions(+), 38 deletions(-) diff --git a/tiny/rna/plotter.py b/tiny/rna/plotter.py index 06283976..7ab2ca14 100644 --- a/tiny/rna/plotter.py +++ b/tiny/rna/plotter.py @@ -291,7 +291,7 @@ def scatter_replicates(count_df: pd.DataFrame, samples: dict, output_prefix: str for samp, reps in samples.items(): for pair in itertools.combinations(reps, 2): rscat = aqplt.scatter_simple(count_df.loc[:,pair[0]], count_df.loc[:,pair[1]], - color='#B3B3B3', alpha=0.5, log_norm=True, rasterized=RASTER) + color='#B3B3B3', alpha=0.5, rasterized=RASTER) aqplt.set_square_scatter_view_lims(rscat, view_lims) aqplt.set_scatter_ticks(rscat) rscat.set_title(samp) @@ -347,10 +347,12 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un output_prefix: A string to use as a prefix for saving files classes: An optional dataframe of class(es) per feature. If provided, points are grouped by class show_unknown: If true, class "unknown" will be included if highlighting by classes + pval: The pvalue threshold for determining the outgroup """ if classes is not None: - uniq_classes = sorted(list(pd.unique(classes))) + uniq_classes = sorted(list(pd.unique(classes)), key=str.lower) + class_colors = aqplt.assign_class_colors(uniq_classes) aqplt.set_dge_class_legend_style() if not show_unknown and 'unknown' in uniq_classes: @@ -358,18 +360,17 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un for pair in dges: p1, p2 = pair.split("_vs_") + # Get list of differentially expressed features for this comparison pair dge_list = list(dges.index[dges[pair] < pval]) + # Create series of feature -> class relationships class_dges = classes.loc[dge_list] - + # Gather lists of features by class (listed in order corresponding to unique_classes) grp_args = [class_dges.index[class_dges == cls].tolist() for cls in uniq_classes] - layer_order = np.argsort([len(grp) for grp in grp_args])[::-1] - sorted_grps = [grp_args[i] for i in layer_order] - sorted_clss = [uniq_classes[i] for i in layer_order] + labels = uniq_classes + sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], *grp_args, colors=class_colors, + pval=pval, view_lims=view_lims, labels=labels, rasterized=RASTER) - labels = ['p ≥ %g' % pval] + sorted_clss - sscat = aqplt.scatter_grouped(count_df.loc[:,p1], count_df.loc[:,p2], view_lims, *sorted_grps, - log_norm=True, labels=labels, rasterized=RASTER) sscat.set_title('%s vs %s' % (p1, p2)) sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1) sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2) @@ -382,9 +383,11 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un grp_args = list(dges.index[dges[pair] < pval]) p1, p2 = pair.split("_vs_") - labels = ['p ≥ %g' % pval, 'p < %g' % pval] - sscat = aqplt.scatter_grouped(count_df.loc[:,p1], count_df.loc[:,p2], view_lims, grp_args, - log_norm=True, labels=labels, alpha=0.5, rasterized=RASTER) + labels = ['p < %g' % pval] + colors = aqplt.assign_class_colors(labels) + sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], grp_args, colors=colors, alpha=0.5, + pval=pval, view_lims=view_lims, labels=labels, rasterized=RASTER) + sscat.set_title('%s vs %s' % (p1, p2)) sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1) sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2) diff --git a/tiny/rna/plotterlib.py b/tiny/rna/plotterlib.py index 36fb7710..3444b3c7 100644 --- a/tiny/rna/plotterlib.py +++ b/tiny/rna/plotterlib.py @@ -160,13 +160,13 @@ def barh_proportion(self, prop_ds: pd.Series, max_prop=1.0, scale=2, **kwargs) - return cbar - def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_norm=False, **kwargs) -> plt.Axes: + def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_scale=True, **kwargs) -> plt.Axes: """Creates a simple scatter plot of counts. Args: count_x: A pandas dataframe/series of counts per feature (X axis) count_y: A pandas dataframe/series of counts per feature (Y axis) - log_norm: Plot on log scale rather than linear scale + log_scale: Plot on log scale rather than linear scale kwargs: Additional keyword arguments to pass to pyplot.Axes.scatter() Returns: @@ -178,7 +178,7 @@ def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_norm=False, ax: plt.Axes # log2 normalize data if requested - if log_norm: + if log_scale: # Set log2 scale ax.set_xscale('log', base=2) ax.set_yscale('log', base=2) @@ -211,51 +211,93 @@ def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_norm=False, return ax - def scatter_grouped(self, count_x: pd.DataFrame, count_y: pd.DataFrame, - view_lims: Tuple[float, float] = None, *args, log_norm=False, labels=None, **kwargs): + def scatter_grouped(self, count_x: pd.DataFrame, count_y: pd.DataFrame, *groups, colors: dict, pval=0.05, + view_lims: Tuple[float, float] = None, labels=None, **kwargs): """Creates a scatter plot with different groups highlighted. Args: count_x: A pandas dataframe/series of counts per feature (X axis) count_y: A pandas dataframe/series of counts per feature (Y axis) + groups: An iterable of lists each representing a group + labels: An iterable of labels corresponding to each group + colors: A dictionary of label -> color for each group view_lims: Optional plot view limits as tuple(min, max) - args: A list of features to highlight, can pass multiple lists - log_norm: whether or not the data should be log-normalized + pval: The p-value to use for the outgroup label + + Keyword Args: + log_scale: Data is plotted on log scale if true (default: true) kwargs: Additional arguments to pass to pyplot.Axes.scatter() Returns: gscat: A scatter plot containing groups highlighted with different colors """ - # Subset counts not in *args (for example, points with p-val above threshold) - count_x_base = count_x.drop(itertools.chain(*args)) - count_y_base = count_y.drop(itertools.chain(*args)) - - if labels is None: - labels = list(range(len(args))) + # Subset counts not in *groups (for example, points with p-val above threshold) + count_x_out = count_x.drop(itertools.chain(*groups)) + count_y_out = count_y.drop(itertools.chain(*groups)) - colors = iter(kwargs.get('colors', plt.rcParams['axes.prop_cycle'].by_key()['color'])) - argsit = iter(args) + outgroup = count_x_out.any() and count_y_out.any() + group_it = iter(groups) - if any([len(outgroup) == 0 for outgroup in [count_x_base, count_y_base]]): - # There is no outgroup, plot the first group with scatter_simple() to set scale and lines - group = next(argsit) - gscat = self.scatter_simple(count_x.loc[group], count_y.loc[group], - log_norm=log_norm, color=next(colors), **kwargs) + if outgroup: + gscat = self.scatter_simple(count_x_out, count_y_out, color='#B3B3B3', **kwargs) else: - # Plot the outgroup in light grey (these are counts not in *args) - gscat = self.scatter_simple(count_x_base,count_y_base, log_norm=log_norm, color='#B3B3B3', **kwargs) + group = next(group_it) + gscat = self.scatter_simple(count_x.loc[group], count_y.loc[group], **kwargs) - # Add each group to plot with a different color - for group in argsit: - gscat.scatter(count_x.loc[group], count_y.loc[group], color=next(colors), edgecolor='none', **kwargs) + # Add any remaining groups to the plot + for group in group_it: + gscat.scatter(count_x.loc[group], count_y.loc[group], edgecolor='none', **kwargs) + self.sort_point_groups_and_label(gscat, groups, labels, colors, outgroup, pval) self.set_square_scatter_view_lims(gscat, view_lims) self.set_scatter_ticks(gscat) - gscat.legend(labels=labels) + return gscat + @staticmethod + def sort_point_groups_and_label(axes: plt.Axes, groups, labels, colors, outgroup, pval): + """Sorts scatter groups so that less abundant groups are plotted on top to maximize visual representation. + After sorting, group colors and labels are assigned, and the legend is created.""" + + lorder = np.argsort([len(grp) for grp in groups])[::-1] # Label index of groups sorted largest -> smallest + ordmax = lorder.max() # Length of groups, sans outgroup + zorder = ordmax - lorder # Z-order for groups (largest values on top) + gorder = lorder + 1 if outgroup else lorder # Group index of sorted groups, sans outgroup + + if outgroup: + axes.collections[0].set_label('p ≥ %g' % pval) + if labels is None: + labels = list(range(len(groups))) + + for G, L, Z in zip(gorder, lorder, zorder): + points = axes.collections[G] + label = labels[L] + points.set_facecolor(colors[label]) + points.set_label(label) + points.set_zorder(Z) + + # Ensure lines remain on top of points + for line in axes.lines: + line.set_zorder(ordmax + 1) + + axes.legend() + + @staticmethod + def assign_class_colors(classes): + """Assigns a color to each class for consistency across samples""" + + stylesheet_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + if len(classes) <= len(stylesheet_colors): + colors = iter(stylesheet_colors) + else: + colors = iter(plt.get_cmap("tab20")) + + return {cls: next(colors) for cls in classes} + def set_dge_class_legend_style(self): + """Widens the "scatter" figure and moves plot to the left to accommodate legend""" + expand_width_inches = 3 fig, scatter = self.reuse_subplot("scatter")