From 4770d59ca2fc744117f93097a2d8e9f336a1cd43 Mon Sep 17 00:00:00 2001
From: IanCa
Date: Mon, 3 Apr 2023 18:45:39 -0500
Subject: [PATCH 1/2] Add more tests, handle more error cases for def gathering

---
 hed/models/df_util.py        | 30 +++++++++++++++++++++++++++---
 tests/models/test_df_util.py | 17 +++++++++++++++--
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/hed/models/df_util.py b/hed/models/df_util.py
index 3027d1174..2c87b3f71 100644
--- a/hed/models/df_util.py
+++ b/hed/models/df_util.py
@@ -139,6 +139,16 @@ def _expand_defs(hed_string, hed_schema, def_dict):
     from hed import HedString
     return str(HedString(hed_string, hed_schema, def_dict).expand_defs())
 
+def _get_matching_value(tags):
+    # Filter out values equal to "#" and get unique values
+    unique_values = set(tag.extension for tag in tags if tag.extension != "#")
+    if len(unique_values) == 0:
+        return "#"
+
+    if len(unique_values) > 1:
+        return None
+
+    return next(iter(unique_values))
 
 def process_def_expands(hed_strings, hed_schema, known_defs=None, ambiguous_defs=None):
     """
@@ -181,6 +191,11 @@ def process_def_expands(hed_strings, hed_schema, known_defs=None, ambiguous_defs
                     errors.setdefault(def_tag_name.lower(), []).append(def_group)
                 continue
 
+            # This is a def we recognized earlier as an error AND it wasn't a known definition.
+            if def_tag_name.lower() in errors:
+                errors.setdefault(def_tag_name.lower(), []).append(def_group)
+                continue
+
             has_extension = "/" in def_tag.extension
 
             # If there's no extension, this is fine.
@@ -206,13 +221,22 @@ def process_def_expands(hed_strings, hed_schema, known_defs=None, ambiguous_defs
                 if len(these_defs) >= 1:
                     all_tags_list = [group.get_all_tags() for group in these_defs]
                     for tags in zip(*all_tags_list):
-                        value_per_tag.append(next((tag.extension for tag in tags if tag.extension != "#"), None))
-                ambiguous_values = value_per_tag.count(None)
+                        matching_val = _get_matching_value(tags)
+                        value_per_tag.append(matching_val)
+
+                if value_per_tag.count(None):
+                    groups = ambiguous_defs.get(def_tag_name.lower(), [])
+                    for group in groups:
+                        errors.setdefault(def_tag_name.lower(), []).append(group)
+
+                    del ambiguous_defs[def_tag_name.lower()]
+                    continue
+                ambiguous_values = value_per_tag.count("#")
                 if ambiguous_values == 1:
                     new_contents = group_tag.copy()
                     for tag, value in zip(new_contents.get_all_tags(), value_per_tag):
                         if value is not None:
-                            tag.extension = f"/{value}"
+                            tag.extension = f"{value}"
                     def_dict.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=new_contents,
                                                                           takes_value=True,
                                                                           source_context=[])
diff --git a/tests/models/test_df_util.py b/tests/models/test_df_util.py
index 5bcdfe704..ff58e35b2 100644
--- a/tests/models/test_df_util.py
+++ b/tests/models/test_df_util.py
@@ -200,13 +200,26 @@ def test_ambiguous_defs(self):
         _, ambiguous_defs, _ = process_def_expands(test_strings, self.schema)
         self.assertEqual(len(ambiguous_defs), 5)
 
+    def test_ambiguous_conflicting_defs(self):
+        # This is invalid due to conflicting defs
+        test_strings = [
+            "(Def-expand/A1/2, (Action/2, Age/5, Item-count/2))",
+            "(Def-expand/A1/3, (Action/3, Age/4, Item-count/3))",
+
+            # This could be identified, but fails due to the above raising errors
+            "(Def-expand/A1/4, (Action/4, Age/5, Item-count/2))",
+        ]
+        defs, ambiguous, errors = process_def_expands(test_strings, self.schema)
+        self.assertEqual(len(defs), 0)
+        self.assertEqual(len(ambiguous), 0)
+        self.assertEqual(len(errors["a1"]), 3)
+
     def test_errors(self):
-        # Cases where you can only retroactively identify the first def-expand
+        # Basic recognition of conflicting errors
         test_strings = [
             "(Def-expand/A1/1, (Action/1, Age/5, Item-count/2))",
             "(Def-expand/A1/2, (Action/2, Age/5, Item-count/2))",
             "(Def-expand/A1/3, (Action/3, Age/5, Item-count/3))",
-
         ]
         _, _, errors = process_def_expands(test_strings, self.schema)
         self.assertEqual(len(errors), 1)

From 9a3e72585de75c22644ca1b50344e44febe84fd6 Mon Sep 17 00:00:00 2001
From: IanCa
Date: Mon, 3 Apr 2023 20:50:36 -0500
Subject: [PATCH 2/2] Refactor into class

---
 hed/models/def_expand_gather.py | 182 ++++++++++++++++++++++++++++++++
 hed/models/df_util.py           |  84 +--------------
 2 files changed, 186 insertions(+), 80 deletions(-)
 create mode 100644 hed/models/def_expand_gather.py

diff --git a/hed/models/def_expand_gather.py b/hed/models/def_expand_gather.py
new file mode 100644
index 000000000..5d2b5c935
--- /dev/null
+++ b/hed/models/def_expand_gather.py
@@ -0,0 +1,182 @@
+import pandas as pd
+from hed.models import DefinitionDict, DefinitionEntry, HedString
+
+
+class DefExpandGatherer:
+    """Class for gathering definitions from a series of def-expands, including possibly ambiguous ones"""
+    def __init__(self, hed_schema, known_defs=None, ambiguous_defs=None, errors=None):
+        """Initialize the DefExpandGatherer class.
+
+        Parameters:
+            hed_schema (HedSchema): The HED schema to be used for processing.
+            known_defs (dict, optional): A dictionary of known definitions.
+            ambiguous_defs (dict, optional): A dictionary of ambiguous def-expand definitions.
+            errors (dict, optional): Lists of erroneous def-expand groups, keyed by lower-cased definition name.
+        """
+        self.hed_schema = hed_schema
+        self.known_defs = known_defs if known_defs else {}
+        self.ambiguous_defs = {} if ambiguous_defs is None else ambiguous_defs  # keep caller's dict even when empty
+        self.errors = {} if errors is None else errors  # same: an empty dict passed in stays shared
+        self.def_dict = DefinitionDict(self.known_defs, self.hed_schema)
+
+    def process_def_expands(self, hed_strings, known_defs=None):
+        """Process the HED strings containing def-expand tags.
+
+        Parameters:
+            hed_strings (pd.Series or list): A Pandas Series or list of HED strings to be processed.
+            known_defs (dict, optional): A dictionary of known definitions to be added.
+
+        Returns:
+            tuple: A tuple containing the DefinitionDict, ambiguous definitions, and errors.
+        """
+        if not isinstance(hed_strings, pd.Series):
+            hed_strings = pd.Series(hed_strings)
+
+        def_expand_mask = hed_strings.str.contains('Def-Expand/', case=False, na=False)
+
+        if known_defs:
+            self.def_dict.add_definitions(known_defs, self.hed_schema)
+        for i in hed_strings[def_expand_mask].index:
+            string = hed_strings.loc[i]
+            self._process_hed_string(string)
+
+        return self.def_dict, self.ambiguous_defs, self.errors
+
+    def _process_hed_string(self, string):
+        """Process a single HED string to extract definitions and handle known and ambiguous definitions.
+
+        Parameters:
+            string (str): The HED string to be processed.
+        """
+        hed_str = HedString(string, self.hed_schema)
+
+        for def_tag, def_expand_group, def_group in hed_str.find_def_tags(recursive=True):
+            if def_tag == def_expand_group:
+                continue
+
+            if not self._handle_known_definition(def_tag, def_expand_group, def_group):
+                self._handle_ambiguous_definition(def_tag, def_expand_group)
+
+    def _handle_known_definition(self, def_tag, def_expand_group, def_group):
+        """Handle known def-expand tag in a HED string.
+
+        Parameters:
+            def_tag (HedTag): The def-expand tag.
+            def_expand_group (HedGroup): The group containing the def-expand tag.
+            def_group (HedGroup): The group containing the def-expand group.
+
+        Returns:
+            bool: True if the def-expand tag is known and handled, False otherwise.
+        """
+        def_tag_name = def_tag.extension.split('/')[0]
+        def_group_contents = self.def_dict._get_definition_contents(def_tag)
+        def_expand_group.sort()
+
+        if def_group_contents:
+            if def_group_contents != def_expand_group:
+                self.errors.setdefault(def_tag_name.lower(), []).append(def_group)
+            return True
+
+        if def_tag_name.lower() in self.errors:
+            self.errors.setdefault(def_tag_name.lower(), []).append(def_group)
+            return True
+
+        return False
+
+    def _handle_ambiguous_definition(self, def_tag, def_expand_group):
+        """Handle ambiguous def-expand tag in a HED string.
+
+        Parameters:
+            def_tag (HedTag): The def-expand tag.
+            def_expand_group (HedGroup): The group containing the def-expand tag.
+        """
+        def_tag_name = def_tag.extension.split('/')[0]
+
+        has_extension = "/" in def_tag.extension
+
+        if not has_extension:
+            group_tag = def_expand_group.get_first_group()
+            self.def_dict.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=group_tag,
+                                                                       takes_value=False,
+                                                                       source_context=[])
+        else:
+            self._process_ambiguous_extension(def_tag, def_expand_group)
+
+    def _process_ambiguous_extension(self, def_tag, def_expand_group):
+        """Process ambiguous extensions in a def-expand HED string.
+
+        Parameters:
+            def_tag (HedTag): The def-expand tag.
+            def_expand_group (HedGroup): The group containing the def-expand tag.
+        """
+        def_tag_name = def_tag.extension.split('/')[0]
+        def_extension = def_tag.extension.split('/')[-1]
+
+        matching_tags = [tag for tag in def_expand_group.get_all_tags() if
+                         tag.extension == def_extension and tag != def_tag]
+
+        for tag in matching_tags:
+            tag.extension = "#"
+
+        group_tag = def_expand_group.get_first_group()
+
+        these_defs = self.ambiguous_defs.setdefault(def_tag_name.lower(), [])
+        these_defs.append(group_tag)
+
+        value_per_tag = []
+        if len(these_defs) >= 1:
+            all_tags_list = [group.get_all_tags() for group in these_defs]
+            for tags in zip(*all_tags_list):
+                matching_val = self._get_matching_value(tags)
+                value_per_tag.append(matching_val)
+
+        self._handle_value_per_tag(def_tag_name, value_per_tag, group_tag)
+
+    def _handle_value_per_tag(self, def_tag_name, value_per_tag, group_tag):
+        """Handle the values per tag in ambiguous def-expand tag.
+
+        Parameters:
+            def_tag_name (str): The name of the def-expand tag.
+            value_per_tag (list): The list of values per HedTag.
+            group_tag (HedGroup): The def expand contents
+        """
+        if value_per_tag.count(None):
+            groups = self.ambiguous_defs.get(def_tag_name.lower(), [])
+            for group in groups:
+                self.errors.setdefault(def_tag_name.lower(), []).append(group)
+
+            del self.ambiguous_defs[def_tag_name.lower()]
+            return
+
+        ambiguous_values = value_per_tag.count("#")
+        if ambiguous_values == 1:
+            new_contents = group_tag.copy()
+            for tag, value in zip(new_contents.get_all_tags(), value_per_tag):
+                if value is not None:
+                    tag.extension = f"{value}"
+            self.def_dict.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=new_contents,
+                                                                       takes_value=True,
+                                                                       source_context=[])
+
+            del self.ambiguous_defs[def_tag_name.lower()]
+
+    @staticmethod
+    def _get_matching_value(tags):
+        """Get the matching value for a set of HedTag extensions.
+
+        Parameters:
+            tags (iterator): The list of HedTags to find a matching value for.
+
+        Returns:
+            str or None: The matching value if found, None otherwise.
+        """
+        extensions = [tag.extension for tag in tags]
+        unique_extensions = set(extensions)
+
+        if len(unique_extensions) == 1:
+            return unique_extensions.pop()
+        elif "#" in unique_extensions:
+            unique_extensions.remove("#")
+            if len(unique_extensions) == 1:
+                return unique_extensions.pop()
+        return None
diff --git a/hed/models/df_util.py b/hed/models/df_util.py
index 2c87b3f71..8d00d6770 100644
--- a/hed/models/df_util.py
+++ b/hed/models/df_util.py
@@ -150,6 +150,7 @@ def _get_matching_value(tags):
 
     return next(iter(unique_values))
 
+
 def process_def_expands(hed_strings, hed_schema, known_defs=None, ambiguous_defs=None):
     """
     Processes a list of HED strings according to a given HED schema, using known definitions and ambiguous definitions.
@@ -163,83 +164,6 @@ def process_def_expands(hed_strings, hed_schema, known_defs=None, ambiguous_defs
         ambiguous_defs (dict): A dictionary containing ambiguous definitions
             format TBD. Currently def name key: list of lists of hed tags values
     """
-    if not isinstance(hed_strings, pd.Series):
-        hed_strings = pd.Series(hed_strings)
-
-    if ambiguous_defs is None:
-        ambiguous_defs = {}
-    errors = {}
-    def_dict = DefinitionDict(known_defs)
-
-    def_expand_mask = hed_strings.str.contains('Def-Expand/', case=False)
-
-    # Iterate over the strings that contain def-expand tags
-    for i in hed_strings[def_expand_mask].index:
-        string = hed_strings.loc[i]
-        hed_str = HedString(string, hed_schema)
-
-        for def_tag, def_expand_group, def_group in hed_str.find_def_tags(recursive=True):
-            if def_tag == def_expand_group:
-                continue
-
-            def_tag_name = def_tag.extension.split('/')[0]
-            # First check for known definitions. If this is known, it's done either way.
-            def_group_contents = def_dict._get_definition_contents(def_tag)
-            def_expand_group.sort()
-            if def_group_contents:
-                if def_group_contents != def_expand_group:
-                    errors.setdefault(def_tag_name.lower(), []).append(def_group)
-                continue
-
-            # This is a def we recognized earlier as an error AND it wasn't a known definition.
-            if def_tag_name.lower() in errors:
-                errors.setdefault(def_tag_name.lower(), []).append(def_group)
-                continue
-
-            has_extension = "/" in def_tag.extension
-
-            # If there's no extension, this is fine.
-            if not has_extension:
-                group_tag = def_expand_group.get_first_group()
-                def_dict.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=group_tag,
-                                                                      takes_value=False,
-                                                                      source_context=[])
-            else:
-                def_extension = def_tag.extension.split('/')[-1]
-                # Find any other tags in def_group.get_all_tags() with tags with the same extension
-                matching_tags = [tag for tag in def_expand_group.get_all_tags() if tag.extension == def_extension and tag != def_tag]
-
-                for tag in matching_tags:
-                    tag.extension = "#"
-
-                group_tag = def_expand_group.get_first_group()
-
-                these_defs = ambiguous_defs.setdefault(def_tag_name.lower(), [])
-                these_defs.append(group_tag)
-
-                value_per_tag = []
-                if len(these_defs) >= 1:
-                    all_tags_list = [group.get_all_tags() for group in these_defs]
-                    for tags in zip(*all_tags_list):
-                        matching_val = _get_matching_value(tags)
-                        value_per_tag.append(matching_val)
-
-                if value_per_tag.count(None):
-                    groups = ambiguous_defs.get(def_tag_name.lower(), [])
-                    for group in groups:
-                        errors.setdefault(def_tag_name.lower(), []).append(group)
-
-                    del ambiguous_defs[def_tag_name.lower()]
-                    continue
-                ambiguous_values = value_per_tag.count("#")
-                if ambiguous_values == 1:
-                    new_contents = group_tag.copy()
-                    for tag, value in zip(new_contents.get_all_tags(), value_per_tag):
-                        if value is not None:
-                            tag.extension = f"{value}"
-                    def_dict.defs[def_tag_name.lower()] = DefinitionEntry(name=def_tag_name, contents=new_contents,
-                                                                          takes_value=True,
-                                                                          source_context=[])
-                    del ambiguous_defs[def_tag_name.lower()]
-
-    return def_dict, ambiguous_defs, errors
+    from hed.models.def_expand_gather import DefExpandGatherer
+    def_gatherer = DefExpandGatherer(hed_schema, known_defs, ambiguous_defs)
+    return def_gatherer.process_def_expands(hed_strings)