diff --git a/START_HERE/run_config.yml b/START_HERE/run_config.yml index 7dfe47da..9358ee6a 100644 --- a/START_HERE/run_config.yml +++ b/START_HERE/run_config.yml @@ -291,6 +291,11 @@ plot_unknown_class: "_UNKNOWN_" ##-- Use this label in class plots for unassigned counts --## plot_unassigned_class: "_UNASSIGNED_" +##-- Optionally filter the classes in class scatter plots --## +plot_class_scatter_filter: + style: include # Choose: include or exclude + classes: [] # Add classes between [ and ], separated by comma + ######----------------------------- OUTPUT DIRECTORIES ------------------------------###### # @@ -367,4 +372,17 @@ run_deseq: True # # The following configuration settings are automatically derived from the Features Sheet # -######-------------------------------------------------------------------------------###### \ No newline at end of file +######-------------------------------------------------------------------------------###### + + + +######--------------------------- DERIVED FROM RUN CONFIG ---------------------------###### +# +# The following configuration settings are automatically derived from this file +# +######-------------------------------------------------------------------------------###### + +##-- Utilized by tiny-plot --## +# Filters for class scatter plots +plot_class_scatter_filter_include: [] +plot_class_scatter_filter_exclude: [] \ No newline at end of file diff --git a/doc/Parameters.md b/doc/Parameters.md index b563f34d..6bc2f52a 100644 --- a/doc/Parameters.md +++ b/doc/Parameters.md @@ -243,13 +243,22 @@ The min and/or max bounds for plotted lengths can be set with this option. See [ The labels that should be used for special groups in `class_charts` and `sample_avg_scatter_by_dge_class` plots. The "unknown" class group represents counts which were assigned by a Features Sheet rule which lacked a "Classify as..." label. The "unassigned" class group represents counts which weren't assigned to a feature. +### Filtering Classes in DGE Class Scatter Plots +| Run Config Key | Commandline Argument | +|----------------------------|----------------------| +| plot_class_scatter_filter: | `--classes-include` | +| | `--classes-exclude` | + +If an inclusive filter is used, then only the classes in the list, if present, are shown. If an exclusive filter is used, then the listed classes are omitted from the plot. This behavior extends to features whose P value is above threshold. In the Run Config, the filter type can be set with the `style:` sub-key, and the desired list of classes for the filter can be provided between the brackets of the `classes:` sub-key + ### Full tiny-plot Help String ``` tiny-plot [-rc RAW_COUNTS] [-nc NORM_COUNTS] [-uc RULE_COUNTS] [-ss STAT] [-dge COMPARISON [COMPARISON ...]] [-len 5P_LEN [5P_LEN ...]] [-o PREFIX] [-pv VALUE] [-s MPLSTYLE] [-v] [-ldi VALUE] [-lda VALUE] [-una LABEL] - [-unk LABEL] -p PLOT [PLOT ...] + [-unk LABEL] [-ic CLASS [CLASS ...] | -ec CLASS [CLASS ...]] + -p PLOT [PLOT ...] This script produces basic static plots for publication as part of the tinyRNA workflow. @@ -267,7 +276,7 @@ Required arguments: • rule_charts: A barchart showing percentages of counts by matched rule. • class_charts: A barchart showing percentages of - counts per Classification. + counts per classification. • replicate_scatter: A scatter plot comparing replicates for all count files given. • sample_avg_scatter_by_dge: A scatter plot comparing @@ -315,4 +324,10 @@ Optional arguments: Use this label in class-related plots for counts which were assigned by rules lacking a "Classify as..." value + -ic CLASS [CLASS ...], --classes-include CLASS [CLASS ...] + Only include these classes, if present, in class + scatter plots (applies regardless of P value) + -ec CLASS [CLASS ...], --classes-exclude CLASS [CLASS ...] + Omit these classes, if present, from class scatter + plots (applies regardless of P value) ``` \ No newline at end of file diff --git a/doc/tiny-plot.md b/doc/tiny-plot.md index 4b9ed158..62a027fd7 100644 --- a/doc/tiny-plot.md +++ b/doc/tiny-plot.md @@ -109,7 +109,7 @@ The P value cutoff can be changed using the [Run Config or commandline arguments ## sample_avg_scatter_by_dge_class -The previous plot type can be extended to group and color differentially expressed features by class. Classes are sorted by abundance before plotting to maximize representation. +The previous plot type can be extended to group and color differentially expressed features by class. Classes are sorted by abundance before plotting to maximize representation. You can also filter the classes displayed using [plot_class_scatter_filter](Parameters.md#filtering-classes-in-dge-class-scatter-plots)

sample_avg_scatter_by_dge_class diff --git a/tests/unit_tests_plotter.py b/tests/unit_tests_plotter.py index 8e12cce5..09b03884 100644 --- a/tests/unit_tests_plotter.py +++ b/tests/unit_tests_plotter.py @@ -1,8 +1,6 @@ -import sys import unittest -from unittest.mock import patch, call +from unittest.mock import patch -import numpy as np import pandas as pd from pandas.testing import assert_frame_equal from pkg_resources import resource_filename @@ -16,11 +14,32 @@ class MyTestCase(unittest.TestCase): def setUpClass(cls): cls.stylesheet = resource_filename('tiny', 'templates/tinyrna-light.mplstyle') + #====== HELPER METHODS =================================================== + def get_label_width_pairs_from_annotations_mock(self, annotations): bar_widths = [i[1]['xycoords'].get_width() for i in annotations.call_args_list] bar_labels = [i[0][0] for i in annotations.call_args_list] return list(zip(bar_labels, bar_widths)) + def aqplt_mock(self): + return patch( + 'tiny.rna.plotter.aqplt', + lib.plotterlib(resource_filename('tiny', 'templates/tinyrna-light.mplstyle')) + ) + + def get_empty_scatter_dge_dataframes(self): + counts = pd.DataFrame( + columns=['Feature ID', 'Classifier', 'ConditionA', 'ConditionB'] + ).set_index(['Feature ID', 'Classifier']) + + dge = pd.DataFrame( + columns=['Feature ID', 'Classifier', 'ConditionA_vs_ConditionB'] + ).set_index(['Feature ID', 'Classifier']) + + return counts, dge + + #====== TESTS ============================================================= + """Are class counts properly calculated?""" def test_class_counts(self): @@ -128,6 +147,37 @@ def test_proportion_chart_percentage_unassigned(self): assert_frame_equal(actual_below_thresh, expected_below_thresh, check_dtype=False, check_like=True) + """Does scatter_by_dge do the right thing when DataFrame inputs are empty?""" + + def test_scatter_by_dge_empty_dataframes(self): + counts, dge = self.get_empty_scatter_dge_dataframes() + + with patch('tiny.rna.plotter.save_plot') as save_plot, self.aqplt_mock(): + plotter.scatter_by_dge(counts, dge, 'dummy_prefix', (0, 0)) + + save_plot.assert_not_called() + + """Does scatter_by_dge_class do the right thing when DataFrame inputs are empty?""" + + def test_scatter_by_dge_class_empty_dataframes(self): + counts, dge = self.get_empty_scatter_dge_dataframes() + + with patch('tiny.rna.plotter.save_plot') as save_plot, self.aqplt_mock(): + plotter.scatter_by_dge_class(counts, dge, 'dummy_prefix', (0, 0)) + + save_plot.assert_not_called() + + """Does scatter_by_dge_class properly handle empty inclusive filter lists?""" + + def test_scatter_dge_class_empty_inclusive_filter(self): + counts, dge = self.get_empty_scatter_dge_dataframes() + + with patch('tiny.rna.plotter.plotterlib.scatter_grouped') as scatter, self.aqplt_mock(): + plotter.scatter_by_dge_class(counts, dge, 'dummy_prefix', (0, 0), include=[]) + scatter.assert_not_called() + plotter.scatter_by_dge_class(counts, dge, 'dummy_prefix', (0, 0), exclude=[]) + scatter.assert_not_called() + if __name__ == '__main__': unittest.main() diff --git a/tiny/cwl/tools/tiny-plot.cwl b/tiny/cwl/tools/tiny-plot.cwl index 8bacd6bf..1edcc7a0 100644 --- a/tiny/cwl/tools/tiny-plot.cwl +++ b/tiny/cwl/tools/tiny-plot.cwl @@ -88,6 +88,22 @@ inputs: prefix: -una doc: 'Use this label in class-related plots for unassigned counts' + classes_include: + type: string[]? + inputBinding: + prefix: -ic + doc: \ + 'Only include these classes, if present, in class scatter ' + 'plots (applies regardless of P value)' + + classes_exclude: + type: string[]? + inputBinding: + prefix: -ec + doc: \ + 'Omit these classes, if present, from class scatter plots ' + '(applies regardless of P value)' + out_prefix: type: string? inputBinding: diff --git a/tiny/cwl/workflows/tinyrna_wf.cwl b/tiny/cwl/workflows/tinyrna_wf.cwl index 39fe0ccf..ae5be1af 100644 --- a/tiny/cwl/workflows/tinyrna_wf.cwl +++ b/tiny/cwl/workflows/tinyrna_wf.cwl @@ -103,6 +103,8 @@ inputs: plot_pval: float? plot_unknown_class: string? plot_unassigned_class: string? + plot_class_scatter_filter_include: string[]? + plot_class_scatter_filter_exclude: string[]? # output directory names dir_name_bt_build: string @@ -258,6 +260,8 @@ steps: $(self.length ? self[0] : null) unknown_class_label: plot_unknown_class unassigned_class_label: plot_unassigned_class + classes_include: plot_class_scatter_filter_include + classes_exclude: plot_class_scatter_filter_exclude dge_pval: plot_pval style_sheet: plot_style_sheet out_prefix: run_name diff --git a/tiny/rna/configuration.py b/tiny/rna/configuration.py index 21ebe081..0f47264b 100644 --- a/tiny/rna/configuration.py +++ b/tiny/rna/configuration.py @@ -136,6 +136,29 @@ def from_here(self, destination: Union[str,dict,None], origin: Union[str,dict,No else: return destination + def setup_step_inputs(self): + """For now, only tiny-plot requires additional setup for step inputs + This function is called at both startup and resume""" + + def setup_tiny_plot_inputs(): + cs_filter = 'plot_class_scatter_filter' + style_req = ['include', 'exclude'] + classes = self.get(cs_filter, {}).get('classes') # backward compatibility + if not classes: return + + # Validate filter style + style = self[cs_filter]['style'].lower() + assert style in style_req, \ + f'{cs_filter} -> style: must be {" or ".join(style_req)}.' + + # Assign the workflow key and reset the other filter(s) + self[f"{cs_filter}_{style}"] = classes.copy() + style_req.remove(style) + for style in style_req: + self[f"{cs_filter}_{style}"] = [] + + setup_tiny_plot_inputs() + def create_run_directory(self) -> str: """Create the destination directory for pipeline outputs""" run_dir = self["run_directory"] @@ -192,6 +215,7 @@ def __init__(self, config_file: str, validate_inputs=False): self.setup_ebwt_idx() self.process_samples_sheet() self.process_features_sheet() + self.setup_step_inputs() if validate_inputs: self.validate_inputs() def load_paths_config(self): diff --git a/tiny/rna/plotter.py b/tiny/rna/plotter.py index 2a01b8ab..2ff72db2 100644 --- a/tiny/rna/plotter.py +++ b/tiny/rna/plotter.py @@ -14,14 +14,14 @@ import re from collections import defaultdict -from typing import Dict, Union, Tuple, DefaultDict +from typing import Dict, Union, Tuple, DefaultDict, Iterable from pkg_resources import resource_filename from tiny.rna.plotterlib import plotterlib from tiny.rna.util import report_execution_time, make_filename, SmartFormatter, timestamp_format, add_transparent_help -aqplt: plotterlib -RASTER: bool +aqplt: plotterlib = None +RASTER: bool = True def get_args(): @@ -71,6 +71,15 @@ def get_args(): help='Use this label in class-related plots for counts which were ' 'assigned by rules lacking a "Classify as..." value') + # Class filtering options + mutex_class_filter = optional_args.add_mutually_exclusive_group() + mutex_class_filter.add_argument('-ic', '--classes-include', metavar='CLASS', nargs='+', type=str, + help='Only include these classes, if present, in class scatter ' + 'plots (applies regardless of P value)') + mutex_class_filter.add_argument('-ec', '--classes-exclude', metavar='CLASS', nargs='+', type=str, + help='Omit these classes, if present, from class scatter plots ' + '(applies regardless of P value)') + # Required arguments required_args.add_argument('-p', '--plots', metavar='PLOT', required=True, nargs='+', help="R|List of plots to create. Options: \n" @@ -339,56 +348,113 @@ def load_dge_tables(comparisons: list, class_fillna: str) -> pd.DataFrame: return de_table -def scatter_dges(count_df, dges, output_prefix, view_lims, classes=None, pval=0.05): - """Creates PDFs of all pairwise comparison scatter plots from a count table. - Can highlight classes and/or differentially expressed genes as different colors. +def filter_dge_classes(count_df: pd.DataFrame, dges: pd.DataFrame, include: Iterable = None, exclude: Iterable = None): + """Filters features by classification in counts and DGE tables. + Arguments `include` and `exclude` are mutually exclusive; providing both + arguments will result in an error. Args: - count_df: A dataframe of counts per feature with multiindex (feature ID, classifier) - dges: A dataframe of differential gene table output to highlight + count_df: A dataframe with an index comprised of (feature ID, classifier) + dges: A dataframe with an index comprised of (feature ID, classifier) + include: An iterable of classifiers to allow + exclude: An iterable of classifiers to exclude + + Returns: A filtered counts_avg_df and a filtered dge table + """ + + if not (include or exclude): + return count_df, dges + elif include and exclude: + raise ValueError("Include/exclude filters are mutually exclusive.") + + if include is not None: + include_lc = [cls.lower() for cls in include] + mask = count_df.index.get_level_values('Classifier').str.lower().isin(include_lc) + elif exclude is not None: + exclude_lc = [cls.lower() for cls in exclude] + mask = ~count_df.index.get_level_values('Classifier').str.lower().isin(exclude_lc) + else: + mask = None # to appease linter + + return count_df[mask], dges[mask] + + +def scatter_by_dge_class(counts_avg_df, dges, output_prefix, view_lims, include=None, exclude=None, pval=0.05): + """Creates PDFs of all pairwise comparison scatter plots with differentially + expressed features colored by class. Counts for features with P value >= `pval` + will be assigned the color grey. + + Args: + counts_avg_df: A dataframe of normalized counts per multiindex (feature ID, classification) that + have been averaged across replicates within each condition. + dges: A differential gene expression dataframe with multiindex (feature ID, classification), + and columns containing adjusted P values for each pairwise comparison. Each column should be + labeled as "conditionA_vs_conditionB" where conditionB is the untreated condition. view_lims: A tuple of (min, max) data bounds for the plot view + include: A list of classes to plot, if present (default: all classes) + exclude: A list of classes to omit from plots (default: no classes) output_prefix: A string to use as a prefix for saving files - classes: An optional feature-class multiindex. If provided, points are grouped by class - pval: The pvalue threshold for determining the outgroup + pval: The P value threshold for determining the outgroup """ - if classes is not None: - uniq_classes = pd.unique(classes.get_level_values(1)) - class_colors = aqplt.assign_class_colors(uniq_classes) - aqplt.set_dge_class_legend_style() + counts_avg_df, dges = filter_dge_classes(counts_avg_df, dges, include, exclude) + if counts_avg_df.empty or dges.empty: return - for pair in dges: - p1, p2 = pair.split("_vs_") - dge_dict = dges[dges[pair] < pval].groupby(level=1).groups + uniq_classes = pd.unique(counts_avg_df.index.get_level_values(1)) + class_colors = aqplt.assign_class_colors(uniq_classes) + aqplt.set_dge_class_legend_style() - labels, grp_args = zip(*dge_dict.items()) if dge_dict else ((), ()) - sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], *grp_args, - colors=class_colors, pval=pval, view_lims=view_lims, labels=labels, - rasterized=RASTER) + for pair in dges: + ut, tr = pair.split("_vs_") # untreated, treated + dge_classes = dges[dges[pair] < pval].groupby(level=1).groups - sscat.set_title('%s vs %s' % (p1, p2)) - sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1) - sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2) - sscat.get_legend().set_bbox_to_anchor((1, 1)) - pdf_name = make_filename([output_prefix, pair, 'scatter_by_dge_class'], ext='.pdf') - save_plot(sscat, "scatter_by_dge_class", pdf_name) + labels, grp_args = zip(*dge_classes.items()) if dge_classes else ((), ()) + sscat = aqplt.scatter_grouped(counts_avg_df.loc[:, ut], counts_avg_df.loc[:, tr], *grp_args, + colors=class_colors, pval=pval, view_lims=view_lims, labels=labels, + rasterized=RASTER) - else: - for pair in dges: - grp_args = list(dges.index[dges[pair] < pval]) - p1, p2 = pair.split("_vs_") + sscat.set_title('%s vs %s' % (tr, ut)) + sscat.set_xlabel("Log$_{2}$ normalized reads in " + ut) + sscat.set_ylabel("Log$_{2}$ normalized reads in " + tr) + sscat.get_legend().set_bbox_to_anchor((1, 1)) + pdf_name = make_filename([output_prefix, pair, 'scatter_by_dge_class'], ext='.pdf') + save_plot(sscat, "scatter_by_dge_class", pdf_name) - labels = ['p < %g' % pval] if grp_args else [] - colors = aqplt.assign_class_colors(labels) - sscat = aqplt.scatter_grouped(count_df.loc[:, p1], count_df.loc[:, p2], grp_args, - colors=colors, alpha=0.5, pval=pval, view_lims=view_lims, labels=labels, - rasterized=RASTER) - sscat.set_title('%s vs %s' % (p1, p2)) - sscat.set_xlabel("Log$_{2}$ normalized reads in " + p1) - sscat.set_ylabel("Log$_{2}$ normalized reads in " + p2) - pdf_name = make_filename([output_prefix, pair, 'scatter_by_dge'], ext='.pdf') - save_plot(sscat, 'scatter_by_dge', pdf_name) +def scatter_by_dge(counts_avg_df, dges, output_prefix, view_lims, pval=0.05): + """Creates PDFs of all pairwise comparison scatter plots with differentially + expressed features highlighted. Counts for features with P value >= `pval` + will be assigned the color grey. + + Args: + counts_avg_df: A dataframe of normalized counts per multiindex (feature ID, classification) that + have been averaged across replicates within each condition. + dges: A differential gene expression dataframe with multiindex (feature ID, classification), + and columns containing adjusted P values for each pairwise comparison. Each column should be + labeled as "conditionA_vs_conditionB" where conditionB is the untreated condition. + view_lims: A tuple of (min, max) data bounds for the plot view + output_prefix: A string to use as a prefix for saving files + pval: The pvalue threshold for determining the outgroup + """ + + if counts_avg_df.empty or dges.empty: + return + + for pair in dges: + grp_args = dges.index[dges[pair] < pval] + ut, tr = pair.split("_vs_") # untreated, treated + + labels = ['p < %g' % pval] if not grp_args.empty else [] + colors = aqplt.assign_class_colors(labels) + sscat = aqplt.scatter_grouped(counts_avg_df.loc[:, ut], counts_avg_df.loc[:, tr], grp_args, + colors=colors, alpha=0.5, pval=pval, view_lims=view_lims, labels=labels, + rasterized=RASTER) + + sscat.set_title('%s vs %s' % (tr, ut)) + sscat.set_xlabel("Log$_{2}$ normalized reads in " + ut) + sscat.set_ylabel("Log$_{2}$ normalized reads in " + tr) + pdf_name = make_filename([output_prefix, pair, 'scatter_by_dge'], ext='.pdf') + save_plot(sscat, 'scatter_by_dge', pdf_name) def load_raw_counts(raw_counts_file: str, class_fillna: str) -> pd.DataFrame: @@ -437,21 +503,6 @@ def set_counts_table_multiindex(counts: pd.DataFrame, fillna: str) -> pd.DataFra return counts.set_index([level0, level1]) -def get_flat_classes(counts_df: pd.DataFrame) -> pd.Index: - """Features with multiple associated classes are returned in flattened form - with one class per entry, yielding multiple entries for these features. During - earlier versions this required some processing, but now we can simply return - the multiindex of the counts_df. - - Args: - counts_df: A DataFrame with a multiindex of (feature ID, feature class) - Returns: - The counts_df multiindex - """ - - return counts_df.index - - def get_class_counts(raw_counts_df: pd.DataFrame) -> pd.DataFrame: """Calculates class counts from level 1 of the raw counts multiindex @@ -548,7 +599,7 @@ def setup(args: argparse.Namespace) -> dict: "de_table_df", "avg_view_lims"], 'sample_avg_scatter_by_dge_class': ["norm_counts_df", "sample_rep_dict", "norm_counts_avg_df", - "feat_classes_df", "de_table_df", "avg_view_lims"] + "de_table_df", "avg_view_lims"] } # These are frozen function pointers; both the function and its @@ -563,7 +614,6 @@ def setup(args: argparse.Namespace) -> dict: 'de_table_df': lambda: load_dge_tables(args.dge_tables, args.unknown_class), 'sample_rep_dict': lambda: get_sample_rep_dict(fetched["norm_counts_df"]), 'norm_counts_avg_df': lambda: get_sample_averages(fetched["norm_counts_df"], fetched["sample_rep_dict"]), - 'feat_classes_df': lambda: get_flat_classes(fetched["norm_counts_df"]), 'class_counts_df': lambda: get_class_counts(fetched["raw_counts_df"]), 'avg_view_lims': lambda: aqplt.get_scatter_view_lims(fetched["norm_counts_avg_df"]), 'norm_view_lims': lambda: aqplt.get_scatter_view_lims(fetched["norm_counts_df"].select_dtypes(['number'])) @@ -611,13 +661,13 @@ def main(): arg = (inputs["norm_counts_df"], inputs["sample_rep_dict"], args.out_prefix, inputs["norm_view_lims"]) kwd = {} elif plot == 'sample_avg_scatter_by_dge': - func = scatter_dges + func = scatter_by_dge arg = (inputs["norm_counts_avg_df"], inputs["de_table_df"], args.out_prefix, inputs["avg_view_lims"]) kwd = {"pval": args.p_value} elif plot == 'sample_avg_scatter_by_dge_class': - func = scatter_dges + func = scatter_by_dge_class arg = (inputs["norm_counts_avg_df"], inputs["de_table_df"], args.out_prefix, inputs["avg_view_lims"]) - kwd = {"classes": inputs["feat_classes_df"], "pval": args.p_value} + kwd = {"pval": args.p_value, "include": args.classes_include, "exclude": args.classes_exclude} else: print('Plot type %s not recognized, please check the -p/--plot arguments' % plot) continue diff --git a/tiny/rna/resume.py b/tiny/rna/resume.py index f1e2745d..e3cd7840 100644 --- a/tiny/rna/resume.py +++ b/tiny/rna/resume.py @@ -39,6 +39,7 @@ def __init__(self, processed_config, workflow, steps, entry_inputs): self.paths = self.load_paths_config() self.assimilate_paths_file() + self.setup_step_inputs() self._create_truncated_workflow() self._rebuild_entry_inputs() diff --git a/tiny/templates/run_config_template.yml b/tiny/templates/run_config_template.yml index c6f6b77e..272c0551 100644 --- a/tiny/templates/run_config_template.yml +++ b/tiny/templates/run_config_template.yml @@ -291,6 +291,11 @@ plot_unknown_class: "_UNKNOWN_" ##-- Use this label in class plots for unassigned counts --## plot_unassigned_class: "_UNASSIGNED_" +##-- Optionally filter the classes in class scatter plots --## +plot_class_scatter_filter: + style: include # Choose: include or exclude + classes: [] # Add classes between [ and ], separated by comma + ######----------------------------- OUTPUT DIRECTORIES ------------------------------###### # @@ -367,4 +372,17 @@ run_deseq: True # # The following configuration settings are automatically derived from the Features Sheet # -######-------------------------------------------------------------------------------###### \ No newline at end of file +######-------------------------------------------------------------------------------###### + + + +######--------------------------- DERIVED FROM RUN CONFIG ---------------------------###### +# +# The following configuration settings are automatically derived from this file +# +######-------------------------------------------------------------------------------###### + +##-- Utilized by tiny-plot --## +# Filters for class scatter plots +plot_class_scatter_filter_include: [] +plot_class_scatter_filter_exclude: [] \ No newline at end of file