Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions tiny/rna/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def scatter_replicates(count_df: pd.DataFrame, samples: dict, output_prefix: str
for samp, reps in samples.items():
for pair in itertools.combinations(reps, 2):
rscat = aqplt.scatter_simple(count_df.loc[:,pair[0]], count_df.loc[:,pair[1]],
color='#B3B3B3', alpha=0.5, log_norm=True, rasterized=RASTER)
color='#B3B3B3', alpha=0.5, rasterized=RASTER)
aqplt.set_square_scatter_view_lims(rscat, view_lims)
aqplt.set_scatter_ticks(rscat)
rscat.set_title(samp)
Expand Down Expand Up @@ -347,29 +347,30 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un
output_prefix: A string to use as a prefix for saving files
classes: An optional dataframe of class(es) per feature. If provided, points are grouped by class
show_unknown: If true, class "unknown" will be included if highlighting by classes
pval: The pvalue threshold for determining the outgroup
"""

if classes is not None:
uniq_classes = sorted(list(pd.unique(classes)))
uniq_classes = sorted(list(pd.unique(classes)), key=str.lower)
class_colors = aqplt.assign_class_colors(uniq_classes)
aqplt.set_dge_class_legend_style()

if not show_unknown and 'unknown' in uniq_classes:
uniq_classes.remove('unknown')

for pair in dges:
p1, p2 = pair.split("_vs_")
# Get list of differentially expressed features for this comparison pair
dge_list = list(dges.index[dges[pair] < pval])
# Create series of feature -> class relationships
class_dges = classes.loc[dge_list]

# Gather lists of features by class (listed in order corresponding to unique_classes)
grp_args = [class_dges.index[class_dges == cls].tolist() for cls in uniq_classes]

layer_order = np.argsort([len(grp) for grp in grp_args])[::-1]
sorted_grps = [grp_args[i] for i in layer_order]
sorted_clss = [uniq_classes[i] for i in layer_order]
labels = uniq_classes
sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], *grp_args, colors=class_colors,
pval=pval, view_lims=view_lims, labels=labels, rasterized=RASTER)

labels = ['p ≥ %g' % pval] + sorted_clss
sscat = aqplt.scatter_grouped(count_df.loc[:,p1], count_df.loc[:,p2], view_lims, *sorted_grps,
log_norm=True, labels=labels, rasterized=RASTER)
sscat.set_title('%s vs %s' % (p1, p2))
sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1)
sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2)
Expand All @@ -382,9 +383,11 @@ def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, show_un
grp_args = list(dges.index[dges[pair] < pval])
p1, p2 = pair.split("_vs_")

labels = ['p ≥ %g' % pval, 'p < %g' % pval]
sscat = aqplt.scatter_grouped(count_df.loc[:,p1], count_df.loc[:,p2], view_lims, grp_args,
log_norm=True, labels=labels, alpha=0.5, rasterized=RASTER)
labels = ['p < %g' % pval]
colors = aqplt.assign_class_colors(labels)
sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], grp_args, colors=colors, alpha=0.5,
pval=pval, view_lims=view_lims, labels=labels, rasterized=RASTER)

sscat.set_title('%s vs %s' % (p1, p2))
sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1)
sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2)
Expand Down
94 changes: 68 additions & 26 deletions tiny/rna/plotterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ def barh_proportion(self, prop_ds: pd.Series, max_prop=1.0, scale=2, **kwargs) -

return cbar

def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_norm=False, **kwargs) -> plt.Axes:
def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_scale=True, **kwargs) -> plt.Axes:
"""Creates a simple scatter plot of counts.

Args:
count_x: A pandas dataframe/series of counts per feature (X axis)
count_y: A pandas dataframe/series of counts per feature (Y axis)
log_norm: Plot on log scale rather than linear scale
log_scale: Plot on log scale rather than linear scale
kwargs: Additional keyword arguments to pass to pyplot.Axes.scatter()

Returns:
Expand All @@ -178,7 +178,7 @@ def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_norm=False,
ax: plt.Axes

# log2 normalize data if requested
if log_norm:
if log_scale:
# Set log2 scale
ax.set_xscale('log', base=2)
ax.set_yscale('log', base=2)
Expand Down Expand Up @@ -211,51 +211,93 @@ def scatter_simple(self, count_x: pd.Series, count_y: pd.Series, log_norm=False,

return ax

def scatter_grouped(self, count_x: pd.DataFrame, count_y: pd.DataFrame,
view_lims: Tuple[float, float] = None, *args, log_norm=False, labels=None, **kwargs):
def scatter_grouped(self, count_x: pd.DataFrame, count_y: pd.DataFrame, *groups, colors: dict, pval=0.05,
view_lims: Tuple[float, float] = None, labels=None, **kwargs):
"""Creates a scatter plot with different groups highlighted.

Args:
count_x: A pandas dataframe/series of counts per feature (X axis)
count_y: A pandas dataframe/series of counts per feature (Y axis)
groups: An iterable of lists each representing a group
labels: An iterable of labels corresponding to each group
colors: A dictionary of label -> color for each group
view_lims: Optional plot view limits as tuple(min, max)
args: A list of features to highlight, can pass multiple lists
log_norm: whether or not the data should be log-normalized
pval: The p-value to use for the outgroup label

Keyword Args:
log_scale: Data is plotted on log scale if true (default: true)
kwargs: Additional arguments to pass to pyplot.Axes.scatter()

Returns:
gscat: A scatter plot containing groups highlighted with different colors
"""

# Subset counts not in *args (for example, points with p-val above threshold)
count_x_base = count_x.drop(itertools.chain(*args))
count_y_base = count_y.drop(itertools.chain(*args))

if labels is None:
labels = list(range(len(args)))
# Subset counts not in *groups (for example, points with p-val above threshold)
count_x_out = count_x.drop(itertools.chain(*groups))
count_y_out = count_y.drop(itertools.chain(*groups))

colors = iter(kwargs.get('colors', plt.rcParams['axes.prop_cycle'].by_key()['color']))
argsit = iter(args)
outgroup = count_x_out.any() and count_y_out.any()
group_it = iter(groups)

if any([len(outgroup) == 0 for outgroup in [count_x_base, count_y_base]]):
# There is no outgroup, plot the first group with scatter_simple() to set scale and lines
group = next(argsit)
gscat = self.scatter_simple(count_x.loc[group], count_y.loc[group],
log_norm=log_norm, color=next(colors), **kwargs)
if outgroup:
gscat = self.scatter_simple(count_x_out, count_y_out, color='#B3B3B3', **kwargs)
else:
# Plot the outgroup in light grey (these are counts not in *args)
gscat = self.scatter_simple(count_x_base,count_y_base, log_norm=log_norm, color='#B3B3B3', **kwargs)
group = next(group_it)
gscat = self.scatter_simple(count_x.loc[group], count_y.loc[group], **kwargs)

# Add each group to plot with a different color
for group in argsit:
gscat.scatter(count_x.loc[group], count_y.loc[group], color=next(colors), edgecolor='none', **kwargs)
# Add any remaining groups to the plot
for group in group_it:
gscat.scatter(count_x.loc[group], count_y.loc[group], edgecolor='none', **kwargs)

self.sort_point_groups_and_label(gscat, groups, labels, colors, outgroup, pval)
self.set_square_scatter_view_lims(gscat, view_lims)
self.set_scatter_ticks(gscat)
gscat.legend(labels=labels)

return gscat

@staticmethod
def sort_point_groups_and_label(axes: plt.Axes, groups, labels, colors, outgroup, pval):
"""Sorts scatter groups so that less abundant groups are plotted on top to maximize visual representation.
After sorting, group colors and labels are assigned, and the legend is created."""

lorder = np.argsort([len(grp) for grp in groups])[::-1] # Label index of groups sorted largest -> smallest
ordmax = lorder.max() # Length of groups, sans outgroup
zorder = ordmax - lorder # Z-order for groups (largest values on top)
gorder = lorder + 1 if outgroup else lorder # Group index of sorted groups, sans outgroup

if outgroup:
axes.collections[0].set_label('p ≥ %g' % pval)
if labels is None:
labels = list(range(len(groups)))

for G, L, Z in zip(gorder, lorder, zorder):
points = axes.collections[G]
label = labels[L]
points.set_facecolor(colors[label])
points.set_label(label)
points.set_zorder(Z)

# Ensure lines remain on top of points
for line in axes.lines:
line.set_zorder(ordmax + 1)

axes.legend()

@staticmethod
def assign_class_colors(classes):
"""Assigns a color to each class for consistency across samples"""

stylesheet_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
if len(classes) <= len(stylesheet_colors):
colors = iter(stylesheet_colors)
else:
colors = iter(plt.get_cmap("tab20"))

return {cls: next(colors) for cls in classes}

def set_dge_class_legend_style(self):
"""Widens the "scatter" figure and moves plot to the left to accommodate legend"""

expand_width_inches = 3

fig, scatter = self.reuse_subplot("scatter")
Expand Down