"""Does validate_r_safe_sample_groups detect group names that will cause namespace collisions in R?"""

def test_validate_r_safe_sample_groups(self):
    # Local import: only part of this test module's import block is visible here
    import re

    # Each of these characters is invalid in an R name and is translated to ".",
    # so every "a<char>b" group name collapses to the same safe name.
    # The backslash is escaped explicitly; "\|" is an invalid escape sequence.
    invalid_chars = "~!@#$%^&*()+-=`<>?/,:;\"'[]{}\\| \t\n\r\f\v"
    non_alphanum_chars = [bad.join(('a', 'b')) for bad in invalid_chars]

    # ".0" receives an "X" prefix, which makes it collide with a literal "X.0"
    leading_dot_number = [".0", "X.0"]

    # A reserved keyword has a dot appended, colliding with "<keyword>."
    reserved_keywords = [(kwd, kwd + '.') for kwd in r_reserved_keywords]

    for bad in [non_alphanum_chars, leading_dot_number, *reserved_keywords]:
        # The expected text is full of regex metacharacters. Unescaped, the "|"
        # from "a|b" acts as an alternation and only the tail of the message
        # would actually be verified.
        msg = re.escape(" ≈ ".join(bad))
        with self.assertRaisesRegex(AssertionError, msg):
            SamplesSheet.validate_r_safe_sample_groups(dict.fromkeys(bad))
# --- tiny/rna/configuration.py (method of SamplesSheet) ---

@staticmethod
def validate_r_safe_sample_groups(sample_groups: Counter):
    """Determine the "syntactically valid" translation of each group name to ensure
    that two groups won't share the same name once translated in tiny-deseq.r

    Args:
        sample_groups: a mapping (or any iterable) of group names; only the
            names are examined.

    Raises:
        AssertionError: if two or more group names translate to the same R name.
    """

    safe_names = defaultdict(list)
    for group in sample_groups:
        safe_names[get_r_safename(group)].append(group)

    collisions = [' ≈ '.join(cluster) for cluster in safe_names.values() if len(cluster) > 1]

    # Raised explicitly rather than with an `assert` statement so that the
    # validation still runs when Python is started with -O
    if collisions:
        raise AssertionError(
            "The following group names are too similar and will cause a namespace collision in R:\n"
            + '\n'.join(collisions))


# --- tiny/rna/util.py ---

r_reserved_keywords = [
    "if", "else", "repeat", "while", "function",
    "for", "in", "next", "break", "TRUE", "FALSE",
    "NULL", "Inf", "NaN", "NA", "NA_integer_",
    "NA_real_", "NA_complex_", "NA_character_"]


def get_r_safename(name: str) -> str:
    """Converts a string to a syntactically valid R name

    This can be used as the Python equivalent of R's make.names() function.
    https://stat.ethz.ch/R-manual/R-devel/library/base/html/make.names.html

    Args:
        name: the name to translate

    Returns:
        The translated name. An empty name yields "X".
    """

    # The character "X" is prepended unless the name starts with a letter, or
    # with a dot that isn't followed by a number. Letters are matched the same
    # way as in the substitution below so that the two rules agree on what a
    # letter is. An empty name also receives the prefix.
    name = re.sub(r"^(?![^\W\d_]|\.(?!\d))", "X", name)

    # If the name contains characters that aren't (locale based) letters,
    # numbers, dot, or underscore, the characters are replaced with a dot
    name = re.sub(r"[^\w.]", ".", name)

    # If the whole name is an R keyword, a dot is appended to the keyword
    if name in r_reserved_keywords:
        name += "."

    return name