From 6eda14aebfecd5657237915a4ddef8beb295092c Mon Sep 17 00:00:00 2001 From: IanCa Date: Wed, 12 Jul 2023 18:43:41 -0500 Subject: [PATCH 1/2] add util functions for hed strings --- hed/models/string_util.py | 59 ++++++++++++++ tests/models/test_string_util.py | 135 +++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 hed/models/string_util.py create mode 100644 tests/models/test_string_util.py diff --git a/hed/models/string_util.py b/hed/models/string_util.py new file mode 100644 index 000000000..0b28aa805 --- /dev/null +++ b/hed/models/string_util.py @@ -0,0 +1,59 @@ +from hed.models.hed_string import HedString + + +def split_base_tags(hed_string, base_tags, remove_group=False): + """ Splits a HedString object into two separate HedString objects based on the presence of base tags. + + Args: + hed_string (HedString): The input HedString object to be split. + base_tags (list of str): A list of strings representing the base tags. + This is matching the base tag NOT all the terms above it. + remove_group (bool, optional): Flag indicating whether to remove the parent group. Defaults to False. + + Returns: + tuple: A tuple containing two HedString objects: + - The first HedString object contains the remaining tags from hed_string. + - The second HedString object contains the tags from hed_string that match the base_tags. 
+ """ + + base_tags = [tag.lower() for tag in base_tags] + include_groups = 0 + if remove_group: + include_groups = 2 + found_things = hed_string.find_tags(base_tags, recursive=True, include_groups=include_groups) + if remove_group: + found_things = [tag if isinstance(group, HedString) else group for tag, group in found_things] + + if found_things: + hed_string.remove(found_things) + + return hed_string, HedString("", hed_string._schema, _contents=found_things) + + +def split_def_tags(hed_string, def_names, remove_group=False): + """ Splits a HedString object into two separate HedString objects based on the presence of wildcard tags. + + This does NOT handle def-expand tags currently. + + Args: + hed_string (HedString): The input HedString object to be split. + def_names (list of str): A list of def names to search for. Can optionally include a value. + remove_group (bool, optional): Flag indicating whether to remove the parent group. Defaults to False. + + Returns: + tuple: A tuple containing two HedString objects: + - The first HedString object contains the remaining tags from hed_string. + - The second HedString object contains the tags from hed_string that match the base_tags. 
+ """ + include_groups = 0 + if remove_group: + include_groups = 2 + wildcard_tags = [f"def/{def_name}".lower() for def_name in def_names] + found_things = hed_string.find_wildcard_tags(wildcard_tags, recursive=True, include_groups=include_groups) + if remove_group: + found_things = [tag if isinstance(group, HedString) else group for tag, group in found_things] + + if found_things: + hed_string.remove(found_things) + + return hed_string, HedString("", hed_string._schema, _contents=found_things) diff --git a/tests/models/test_string_util.py b/tests/models/test_string_util.py new file mode 100644 index 000000000..6819d7bca --- /dev/null +++ b/tests/models/test_string_util.py @@ -0,0 +1,135 @@ +import unittest +from hed import HedString, load_schema_version +from hed.models.string_util import split_base_tags, split_def_tags +import copy + + +class TestHedStringSplit(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.schema = load_schema_version() + + def check_split_base_tags(self, hed_string, base_tags, expected_string, expected_string2): + # Test case 1: remove_group=False + hed_string_copy = copy.deepcopy(hed_string) + remaining_hed, found_hed = split_base_tags(hed_string_copy, base_tags, remove_group=False) + + self.assertIsInstance(remaining_hed, HedString) + self.assertIsInstance(found_hed, HedString) + self.assertEqual(str(remaining_hed), expected_string) + + self.assertTrue(all(tag in [str(t) for t in found_hed.get_all_tags()] for tag in base_tags)) + self.assertTrue(all(tag not in [str(t) for t in remaining_hed.get_all_tags()] for tag in base_tags)) + + # Test case 2: remove_group=True + hed_string_copy = copy.deepcopy(hed_string) + remaining_hed, found_hed = split_base_tags(hed_string_copy, base_tags, remove_group=True) + + self.assertIsInstance(remaining_hed, HedString) + self.assertIsInstance(found_hed, HedString) + self.assertEqual(str(remaining_hed), expected_string2) + + self.assertTrue(all(tag in [str(t) for t in found_hed.get_all_tags()] 
for tag in base_tags)) + self.assertTrue(all(tag not in [str(t) for t in remaining_hed.get_all_tags()] for tag in base_tags)) + + def test_case_1(self): + hed_string = HedString('Memorize,Action,Area', self.schema) + base_tags = ['Area', 'Action'] + expected_string = 'Memorize' + expected_string2 = 'Memorize' + self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2) + + def test_case_2(self): + hed_string = HedString('Area,LightBlue,Handedness', self.schema) + base_tags = ['Area', 'LightBlue'] + expected_string = 'Handedness' + expected_string2 = 'Handedness' + self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2) + + def test_case_3(self): + hed_string = HedString('(Wink,Communicate),Face,HotPink', self.schema) + base_tags = ['Wink', 'Face'] + expected_string = '(Communicate),HotPink' + expected_string2 = "HotPink" + self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2) + + def test_case_4(self): + hed_string = HedString('(Area,(LightBlue,Handedness,(Wink,Communicate))),Face,HotPink', self.schema) + base_tags = ['Area', 'LightBlue'] + expected_string = '((Handedness,(Wink,Communicate))),Face,HotPink' + expected_string2 = 'Face,HotPink' + self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2) + + def test_case_5(self): + hed_string = HedString('(Memorize,(Action,(Area,LightBlue),Handedness),Wink)', self.schema) + base_tags = ['Area', 'LightBlue'] + expected_string = '(Memorize,(Action,Handedness),Wink)' + expected_string2 = '(Memorize,(Action,Handedness),Wink)' + self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2) + +class TestHedStringSplitDef(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.schema = load_schema_version() + + def check_split_def_tags(self, hed_string, def_names, expected_string, expected_string2): + # Test case 1: remove_group=False + hed_string_copy = copy.deepcopy(hed_string) + 
remaining_hed, found_hed = split_def_tags(hed_string_copy, def_names, remove_group=False) + + self.assertIsInstance(remaining_hed, HedString) + self.assertIsInstance(found_hed, HedString) + self.assertEqual(str(remaining_hed), expected_string) + + self.assertTrue(all(tag.short_base_tag == "Def" for tag in found_hed.get_all_tags())) + self.assertTrue(all(tag.short_base_tag != "Def" for tag in remaining_hed.get_all_tags())) + + # Test case 2: remove_group=True + hed_string_copy = copy.deepcopy(hed_string) + remaining_hed, found_hed = split_def_tags(hed_string_copy, def_names, remove_group=True) + + self.assertIsInstance(remaining_hed, HedString) + self.assertIsInstance(found_hed, HedString) + self.assertEqual(str(remaining_hed), expected_string2) + + #self.assertTrue(all(tag.short_base_tag == "Def" for tag in found_hed.get_all_tags())) + self.assertTrue(all(tag.short_base_tag != "Def" for tag in remaining_hed.get_all_tags())) + + def test_case_1(self): + hed_string = HedString('Memorize,Action,def/CustomTag1', self.schema) + def_names = ['CustomTag1'] + expected_string = 'Memorize,Action' + expected_string2 = 'Memorize,Action' + self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2) + + def test_case_2(self): + hed_string = HedString('def/CustomTag1,LightBlue,def/CustomTag2/123', self.schema) + def_names = ['CustomTag1', 'CustomTag2'] + expected_string = 'LightBlue' + expected_string2 = 'LightBlue' + self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2) + + def test_case_3(self): + hed_string = HedString('(def/CustomTag1,Communicate),Face,def/CustomTag3/abc', self.schema) + def_names = ['CustomTag1', 'CustomTag3'] + expected_string = '(Communicate),Face' + expected_string2 = 'Face' + self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2) + + def test_case_4(self): + hed_string = HedString('(def/CustomTag1,(LightBlue,def/CustomTag2/123,(Wink,Communicate))),Face,def/CustomTag3/abc', 
self.schema) + def_names = ['CustomTag1', 'CustomTag2', 'CustomTag3'] + expected_string = '((LightBlue,(Wink,Communicate))),Face' + expected_string2 = 'Face' + self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2) + + def test_case_5(self): + hed_string = HedString('(Memorize,(Action,(def/CustomTag1,LightBlue),def/CustomTag2/123),Wink)', self.schema) + def_names = ['CustomTag1', 'CustomTag2'] + expected_string = '(Memorize,(Action,(LightBlue)),Wink)' + expected_string2 = '(Memorize,Wink)' + self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 5e067e6e88bf8a6497af9683cf5c3c3920884ec9 Mon Sep 17 00:00:00 2001 From: IanCa Date: Tue, 18 Jul 2023 17:45:16 -0500 Subject: [PATCH 2/2] Add schema comparison functions. Rename AllTags to just Tags. --- hed/models/string_util.py | 2 +- hed/schema/hed_schema.py | 30 +-- hed/schema/hed_schema_base.py | 6 +- hed/schema/hed_schema_constants.py | 2 +- hed/schema/hed_schema_group.py | 12 +- hed/schema/hed_schema_section.py | 2 +- hed/schema/schema_attribute_validators.py | 4 +- hed/schema/schema_compare.py | 202 ++++++++++++++++++ hed/schema/schema_compliance.py | 2 + hed/schema/schema_io/schema2base.py | 8 +- hed/schema/schema_io/schema2xml.py | 2 +- hed/schema/schema_io/wiki2schema.py | 6 +- hed/schema/schema_io/wiki_constants.py | 2 +- hed/schema/schema_io/xml2schema.py | 8 +- hed/schema/schema_io/xml_constants.py | 6 +- hed/schema/schema_validation_util.py | 2 +- hed/tools/util/schema_util.py | 2 +- tests/schema/test_hed_schema.py | 2 +- tests/schema/test_hed_schema_io.py | 4 +- .../test_schema_attribute_validators.py | 12 +- tests/schema/test_schema_compare.py | 105 +++++++++ tests/schema/test_schema_compliance.py | 2 - 22 files changed, 365 insertions(+), 58 deletions(-) create mode 100644 hed/schema/schema_compare.py create mode 100644 tests/schema/test_schema_compare.py diff --git 
a/hed/models/string_util.py b/hed/models/string_util.py index 0b28aa805..9f6023223 100644 --- a/hed/models/string_util.py +++ b/hed/models/string_util.py @@ -43,7 +43,7 @@ def split_def_tags(hed_string, def_names, remove_group=False): Returns: tuple: A tuple containing two HedString objects: - The first HedString object contains the remaining tags from hed_string. - - The second HedString object contains the tags from hed_string that match the base_tags. + - The second HedString object contains the tags from hed_string that match the def_names. """ include_groups = 0 if remove_group: diff --git a/hed/schema/hed_schema.py b/hed/schema/hed_schema.py index c0f1d2589..7b71b2bb1 100644 --- a/hed/schema/hed_schema.py +++ b/hed/schema/hed_schema.py @@ -86,13 +86,13 @@ def merged(self): return not self.header_attributes.get(constants.UNMERGED_ATTRIBUTE, "") @property - def all_tags(self): + def tags(self): """ Return the tag schema section. Returns: HedSchemaTagSection: The tag section. """ - return self._sections[HedSectionKey.AllTags] + return self._sections[HedSectionKey.Tags] @property def unit_classes(self): @@ -354,7 +354,7 @@ def check_compliance(self, check_for_warnings=True, name=None, error_handler=Non from hed.schema import schema_compliance return schema_compliance.check_compliance(self, check_for_warnings, name, error_handler) - def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags): + def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags): """ Return tag entries with the given attribute. 
Parameters: @@ -370,7 +370,7 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags): return self._sections[key_class].get_entries_with_attribute(attribute, return_name_only=True, schema_namespace=self._namespace) - def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""): + def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""): """ Return the schema entry for this tag, if one exists. Parameters: @@ -378,12 +378,12 @@ def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace= This will not handle extensions or similar. If this is a tag, it can have a schema namespace, but it's not required key_class (HedSectionKey or str): The type of entry to return. - schema_namespace (str): Only used on AllTags. If incorrect, will return None. + schema_namespace (str): Only used on Tags. If incorrect, will return None. Returns: HedSchemaEntry: The schema entry for the given tag. """ - if key_class == HedSectionKey.AllTags: + if key_class == HedSectionKey.Tags: if schema_namespace != self._namespace: return None if name.startswith(self._namespace): @@ -415,7 +415,7 @@ def find_tag_entry(self, tag, schema_namespace=""): # =============================================== # Private utility functions for getting/finding tags # =============================================== - def _get_tag_entry(self, name, key_class=HedSectionKey.AllTags): + def _get_tag_entry(self, name, key_class=HedSectionKey.Tags): """ Return the schema entry for this tag, if one exists. 
Parameters: @@ -524,7 +524,7 @@ def _validate_remaining_terms(self, tag, working_tag, prefix_tag_adj, current_sl tag, index_in_tag=word_start_index, index_in_tag_end=word_start_index + len(name), - expected_parent_tag=self.all_tags[name].name) + expected_parent_tag=self.tags[name].name) raise self._TagIdentifyError(error) word_start_index += len(name) + 1 @@ -533,7 +533,7 @@ def _validate_remaining_terms(self, tag, working_tag, prefix_tag_adj, current_sl # =============================================== def finalize_dictionaries(self): """ Call to finish loading. """ - self._has_duplicate_tags = bool(self.all_tags.duplicate_names) + self._has_duplicate_tags = bool(self.tags.duplicate_names) self._update_all_entries() def _update_all_entries(self): @@ -568,13 +568,13 @@ def get_desc_iter(self): if tag_entry.description: yield tag_entry.name, tag_entry.description - def get_tag_description(self, tag_name, key_class=HedSectionKey.AllTags): + def get_tag_description(self, tag_name, key_class=HedSectionKey.Tags): """ Return the description associated with the tag. Parameters: tag_name (str): A hed tag name(or unit/unit modifier etc) with proper capitalization. key_class (str): A string indicating type of description (e.g. All tags, Units, Unit modifier). - The default is HedSectionKey.AllTags. + The default is HedSectionKey.Tags. Returns: str: A description of the specified tag. 
@@ -595,7 +595,7 @@ def get_all_schema_tags(self, return_last_term=False): """ final_list = [] - for lower_tag, tag_entry in self.all_tags.items(): + for lower_tag, tag_entry in self.tags.items(): if return_last_term: final_list.append(tag_entry.name.split('/')[-1]) else: @@ -636,7 +636,7 @@ def get_tag_attribute_names(self): and not tag_entry.has_attribute(HedKey.UnitModifierProperty) and not tag_entry.has_attribute(HedKey.ValueClassProperty)} - def get_all_tag_attributes(self, tag_name, key_class=HedSectionKey.AllTags): + def get_all_tag_attributes(self, tag_name, key_class=HedSectionKey.Tags): """ Gather all attributes for a given tag name. Parameters: @@ -670,7 +670,7 @@ def _create_empty_sections(): dictionaries[HedSectionKey.Units] = HedSchemaSection(HedSectionKey.Units) dictionaries[HedSectionKey.UnitClasses] = HedSchemaUnitClassSection(HedSectionKey.UnitClasses) dictionaries[HedSectionKey.ValueClasses] = HedSchemaSection(HedSectionKey.ValueClasses) - dictionaries[HedSectionKey.AllTags] = HedSchemaTagSection(HedSectionKey.AllTags, case_sensitive=False) + dictionaries[HedSectionKey.Tags] = HedSchemaTagSection(HedSectionKey.Tags, case_sensitive=False) return dictionaries @@ -717,7 +717,7 @@ def _get_attributes_for_section(self, key_class): dict or HedSchemaSection: A dict of all the attributes and this section. 
""" - if key_class == HedSectionKey.AllTags: + if key_class == HedSectionKey.Tags: return self.get_tag_attribute_names() elif key_class == HedSectionKey.Attributes: prop_added_dict = {key: value for key, value in self._sections[HedSectionKey.Properties].items()} diff --git a/hed/schema/hed_schema_base.py b/hed/schema/hed_schema_base.py index b0e29ebcc..6651077e0 100644 --- a/hed/schema/hed_schema_base.py +++ b/hed/schema/hed_schema_base.py @@ -55,7 +55,7 @@ def valid_prefixes(self): raise NotImplemented("This function must be implemented in the baseclass") @abstractmethod - def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags): + def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags): """ Return tag entries with the given attribute. Parameters: @@ -72,7 +72,7 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags): # todo: maybe tweak this API so you don't have to pass in library namespace? @abstractmethod - def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""): + def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""): """ Return the schema entry for this tag, if one exists. Parameters: @@ -80,7 +80,7 @@ def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace= This will not handle extensions or similar. If this is a tag, it can have a schema namespace, but it's not required key_class (HedSectionKey or str): The type of entry to return. - schema_namespace (str): Only used on AllTags. If incorrect, will return None. + schema_namespace (str): Only used on Tags. If incorrect, will return None. Returns: HedSchemaEntry: The schema entry for the given tag. 
diff --git a/hed/schema/hed_schema_constants.py b/hed/schema/hed_schema_constants.py index 03a783fd4..e74f1b70f 100644 --- a/hed/schema/hed_schema_constants.py +++ b/hed/schema/hed_schema_constants.py @@ -4,7 +4,7 @@ class HedSectionKey(Enum): """ Kegs designating specific sections in a HedSchema object. """ # overarching category listing all tags - AllTags = 'tags' + Tags = 'tags' # Overarching category listing all unit classes UnitClasses = 'unitClasses' # Overarching category listing all units(not divided by type) diff --git a/hed/schema/hed_schema_group.py b/hed/schema/hed_schema_group.py index 96187b73f..ae0ac2b81 100644 --- a/hed/schema/hed_schema_group.py +++ b/hed/schema/hed_schema_group.py @@ -101,7 +101,7 @@ def check_compliance(self, check_for_warnings=True, name=None, error_handler=Non issues_list += schema.check_compliance(check_for_warnings, name, error_handler) return issues_list - def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags): + def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags): """ Return tag entries with the given attribute. Parameters: @@ -114,12 +114,12 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags): Notes: - The result is cached so will be fast after first call. """ - all_tags = set() + tags = set() for schema in self._schemas.values(): - all_tags.update(schema.get_tags_with_attribute(attribute, key_class)) - return list(all_tags) + tags.update(schema.get_tags_with_attribute(attribute, key_class)) + return list(tags) - def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""): + def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""): """ Return the schema entry for this tag, if one exists. Parameters: @@ -127,7 +127,7 @@ def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace= This will not handle extensions or similar. 
If this is a tag, it can have a schema namespace, but it's not required key_class (HedSectionKey or str): The type of entry to return. - schema_namespace (str): Only used on AllTags. If incorrect, will return None. + schema_namespace (str): Only used on Tags. If incorrect, will return None. Returns: HedSchemaEntry: The schema entry for the given tag. diff --git a/hed/schema/hed_schema_section.py b/hed/schema/hed_schema_section.py index 2da2ff54b..9b6ba657e 100644 --- a/hed/schema/hed_schema_section.py +++ b/hed/schema/hed_schema_section.py @@ -9,7 +9,7 @@ HedSectionKey.Units: UnitEntry, HedSectionKey.UnitClasses: UnitClassEntry, HedSectionKey.ValueClasses: HedSchemaEntry, - HedSectionKey.AllTags: HedTagEntry, + HedSectionKey.Tags: HedTagEntry, } diff --git a/hed/schema/schema_attribute_validators.py b/hed/schema/schema_attribute_validators.py index 2fa23d1db..47dc7410b 100644 --- a/hed/schema/schema_attribute_validators.py +++ b/hed/schema/schema_attribute_validators.py @@ -50,7 +50,7 @@ def tag_exists_check(hed_schema, tag_entry, attribute_name): possible_tags = tag_entry.attributes.get(attribute_name, "") split_tags = possible_tags.split(",") for org_tag in split_tags: - if org_tag and org_tag not in hed_schema.all_tags: + if org_tag and org_tag not in hed_schema.tags: issues += ErrorHandler.format_error(ValidationErrors.NO_VALID_TAG_FOUND, org_tag, index_in_tag=0, @@ -72,7 +72,7 @@ def tag_exists_base_schema_check(hed_schema, tag_entry, attribute_name): """ issues = [] rooted_tag = tag_entry.attributes.get(attribute_name, "") - if rooted_tag and rooted_tag not in hed_schema.all_tags: + if rooted_tag and rooted_tag not in hed_schema.tags: issues += ErrorHandler.format_error(ValidationErrors.NO_VALID_TAG_FOUND, rooted_tag, index_in_tag=0, diff --git a/hed/schema/schema_compare.py b/hed/schema/schema_compare.py new file mode 100644 index 000000000..49d5c72b6 --- /dev/null +++ b/hed/schema/schema_compare.py @@ -0,0 +1,202 @@ +from hed.schema.hed_schema import 
HedSchema, HedKey +from hed.schema.hed_schema_constants import HedSectionKey + + +def find_matching_tags(schema1, schema2, return_string=False): + """ + Compare the tags in two library schemas. This finds tags with the same term. + + Parameters: + schema1 (HedSchema): The first schema to be compared. + schema2 (HedSchema): The second schema to be compared. + return_string (bool): Return this as a string if true + + Returns: + dict or str: A dictionary containing matching entries in the Tags section of both schemas. + """ + matches, _, _, unequal_entries = compare_schemas(schema1, schema2) + + for section_key, section_dict in matches.items(): + section_dict.update(unequal_entries[section_key]) + + if return_string: + return "\n".join([pretty_print_diff_all(entries, prompt="Found matching node ") for entries in matches.values()]) + return matches + + +def compare_differences(schema1, schema2, return_string=False, attribute_filter=None): + """ + Compare the tags in two schemas, this finds any differences + + Parameters: + schema1 (HedSchema): The first schema to be compared. + schema2 (HedSchema): The second schema to be compared. + return_string (bool): Return this as a string if true + attribute_filter (str, optional): The attribute to filter entries by. + Entries without this attribute are skipped. + The most common use would be HedKey.InLibrary + If it evaluates to False, no filtering is performed. + Returns: + tuple or str: A tuple containing three dictionaries: + - not_in_schema1(dict): Entries present in schema2 but not in schema1. + - not_in_schema2(dict): Entries present in schema1 but not in schema2. + - unequal_entries(dict): Entries that differ between the two schemas. 
+ - or a formatted string of the differences + """ + _, not_in_1, not_in_2, unequal_entries = compare_schemas(schema1, schema2, attribute_filter=attribute_filter) + + if return_string: + str1 = "\n".join([pretty_print_diff_all(entries) for entries in unequal_entries.values()]) + "\n" + str2 = "\n".join([pretty_print_missing_all(entries, "Schema1") for entries in not_in_1.values()]) + "\n" + str3 = "\n".join([pretty_print_missing_all(entries, "Schema2") for entries in not_in_2.values()]) + return str1 + str2 + str3 + return not_in_1, not_in_2, unequal_entries + + +def compare_schemas(schema1, schema2, attribute_filter=HedKey.InLibrary, sections=(HedSectionKey.Tags,)): + """ + Compare two schemas section by section. + The function records matching entries, entries present in one schema but not in the other, and unequal entries. + + Parameters: + schema1 (HedSchema): The first schema to be compared. + schema2 (HedSchema): The second schema to be compared. + attribute_filter (str, optional): The attribute to filter entries by. + Entries without this attribute are skipped. + If it evaluates to False, no filtering is performed. + sections(list): the list of sections to compare. By default, just the tags section. + + Returns: + tuple: A tuple containing four dictionaries: + - matches(dict): Entries present in both schemas and are equal. + - not_in_schema1(dict): Entries present in schema2 but not in schema1. + - not_in_schema2(dict): Entries present in schema1 but not in schema2. + - unequal_entries(dict): Entries present in both schemas but are not equal. 
+ """ + # Result dictionaries to hold matches, keys not in schema2, keys not in schema1, and unequal entries + matches = {} + not_in_schema2 = {} + not_in_schema1 = {} + unequal_entries = {} + + # Iterate over keys in HedSectionKey + for section_key in HedSectionKey: + if not sections or section_key not in sections: + continue + # Dictionaries to record (short_tag_name or name): entry pairs + dict1 = {} + dict2 = {} + + section1 = schema1[section_key] + section2 = schema2[section_key] + + attribute = 'short_tag_name' if section_key == HedSectionKey.Tags else 'name' + + for entry in section1.all_entries: + if not attribute_filter or entry.has_attribute(attribute_filter): + dict1[getattr(entry, attribute)] = entry + + for entry in section2.all_entries: + if not attribute_filter or entry.has_attribute(attribute_filter): + dict2[getattr(entry, attribute)] = entry + + # Find keys present in dict1 but not in dict2, and vice versa + not_in_schema2[section_key] = {key: dict1[key] for key in dict1 if key not in dict2} + not_in_schema1[section_key] = {key: dict2[key] for key in dict2 if key not in dict1} + + # Find keys present in both but with unequal entries + unequal_entries[section_key] = {key: (dict1[key], dict2[key]) for key in dict1 + if key in dict2 and dict1[key] != dict2[key]} + + # Find matches + matches[section_key] = {key: (dict1[key], dict2[key]) for key in dict1 + if key in dict2 and dict1[key] == dict2[key]} + + return matches, not_in_schema1, not_in_schema2, unequal_entries + + +def pretty_print_diff_entry(entry1, entry2): + """ + Returns the differences between two HedSchemaEntry objects as a list of strings + + Parameters: + entry1 (HedSchemaEntry): The first entry. + entry2 (HedSchemaEntry): The second entry. 
+ + Returns: + diff_lines(list): the differences as a list of strings + """ + output = [] + # Checking if both entries have the same name + if entry1.name != entry2.name: + output.append(f"\tName differs: '{entry1.name}' vs '{entry2.name}'") + + # Checking if both entries have the same description + if entry1.description != entry2.description: + output.append(f"\tDescription differs: '{entry1.description}' vs '{entry2.description}'") + + # Comparing attributes + for attr in set(entry1.attributes.keys()).union(entry2.attributes.keys()): + if entry1.attributes.get(attr) != entry2.attributes.get(attr): + output.append(f"\tAttribute '{attr}' differs: '{entry1.attributes.get(attr)}' vs '{entry2.attributes.get(attr)}'") + + return output + + +def pretty_print_entry(entry): + """ Returns the contents of a HedSchemaEntry object as a list of strings. + + Parameters: + entry (HedSchemaEntry): The HedSchemaEntry object to be displayed. + + Returns: + List of strings representing the entry. + """ + # Initialize the list with the name of the entry + output = [f"\tName: {entry.name}"] + + # Add the description to the list if it exists + if entry.description is not None: + output.append(f"\tDescription: {entry.description}") + + # Iterate over all attributes and add them to the list + for attr_key, attr_value in entry.attributes.items(): + output.append(f"\tAttribute: {attr_key} - Value: {attr_value}") + + return output + + +def pretty_print_diff_all(entries, prompt="Differences for "): + """ + Formats the differences between pairs of HedSchemaEntry objects. + + Parameters: + entries (dict): A dictionary where each key maps to a pair of HedSchemaEntry objects. 
+ prompt(str): The prompt for each entry + Returns: + diff_string(str): The differences found in the dict + """ + output = [] + for key, (entry1, entry2) in entries.items(): + output.append(f"{prompt}'{key}':") + output += pretty_print_diff_entry(entry1, entry2) + + return "\n".join(output) + + +def pretty_print_missing_all(entries, schema_name): + """ + Formats the missing entries from schema_name. + + Parameters: + entries (dict): A dictionary where each key maps to a single HedSchemaEntry object. + schema_name(str): The name of the schema these entries are missing from + Returns: + missing_string(str): The missing entries formatted as a single string + """ + output = [] + for key, entry in entries.items(): + output.append(f"'{key}' not in {schema_name}':") + output += pretty_print_entry(entry) + + return "\n".join(output) \ No newline at end of file diff --git a/hed/schema/schema_compliance.py b/hed/schema/schema_compliance.py index 20db73376..59770e6b2 100644 --- a/hed/schema/schema_compliance.py +++ b/hed/schema/schema_compliance.py @@ -114,3 +114,5 @@ def check_invalid_chars(self): for tag_name, desc in self.hed_schema.get_desc_iter(): issues_list += validate_schema_description(tag_name, desc) return issues_list + + diff --git a/hed/schema/schema_io/schema2base.py b/hed/schema/schema_io/schema2base.py index 7ed8aceaf..280562fd9 100644 --- a/hed/schema/schema_io/schema2base.py +++ b/hed/schema/schema_io/schema2base.py @@ -41,7 +41,7 @@ def process_schema(self, hed_schema, save_merged=False): self._output_header(hed_schema.get_save_header_attributes(self._save_merged), hed_schema.prologue) - self._output_tags(hed_schema.all_tags) + self._output_tags(hed_schema.tags) self._output_units(hed_schema.unit_classes) self._output_section(hed_schema, HedSectionKey.UnitModifiers) self._output_section(hed_schema, HedSectionKey.ValueClasses) @@ -69,13 +69,13 @@ def _write_tag_entry(self, tag_entry, parent=None, level=0): def _write_entry(self, entry, parent_node, include_props=True): raise 
NotImplementedError("This needs to be defined in the subclass") - def _output_tags(self, all_tags): - schema_node = self._start_section(HedSectionKey.AllTags) + def _output_tags(self, tags): + schema_node = self._start_section(HedSectionKey.Tags) # This assumes .all_entries is sorted in a reasonable way for output. level_adj = 0 all_nodes = {} # List of all nodes we've written out. - for tag_entry in all_tags.all_entries: + for tag_entry in tags.all_entries: if self._should_skip(tag_entry): continue tag = tag_entry.name diff --git a/hed/schema/schema_io/schema2xml.py b/hed/schema/schema_io/schema2xml.py index 974480347..c25e38074 100644 --- a/hed/schema/schema_io/schema2xml.py +++ b/hed/schema/schema_io/schema2xml.py @@ -52,7 +52,7 @@ def _write_tag_entry(self, tag_entry, parent_node=None, level=0): SubElement The added node """ - key_class = HedSectionKey.AllTags + key_class = HedSectionKey.Tags tag_element = xml_constants.ELEMENT_NAMES[key_class] tag_description = tag_entry.description tag_attributes = tag_entry.attributes diff --git a/hed/schema/schema_io/wiki2schema.py b/hed/schema/schema_io/wiki2schema.py index a66dc7b2e..1b91da9d1 100644 --- a/hed/schema/schema_io/wiki2schema.py +++ b/hed/schema/schema_io/wiki2schema.py @@ -290,7 +290,7 @@ def _read_schema(self, lines): lines: [(int, str)] Lines for this section """ - self._schema._initialize_attributes(HedSectionKey.AllTags) + self._schema._initialize_attributes(HedSectionKey.Tags) parent_tags = [] level_adj = 0 for line_number, line in lines: @@ -322,7 +322,7 @@ def _read_schema(self, lines): self._add_fatal_error(line_number, line, e.message, e.code) continue - tag_entry = self._add_to_dict(line_number, line, tag_entry, HedSectionKey.AllTags) + tag_entry = self._add_to_dict(line_number, line, tag_entry, HedSectionKey.Tags) parent_tags.append(tag_entry.short_tag_name) @@ -594,7 +594,7 @@ def _add_tag_line(self, parent_tags, line_number, tag_line): long_tag_name = "/".join(parent_tags) + "/" + tag_name else: 
long_tag_name = tag_name - return self._create_entry(line_number, tag_line, HedSectionKey.AllTags, long_tag_name) + return self._create_entry(line_number, tag_line, HedSectionKey.Tags, long_tag_name) self._add_fatal_error(line_number, tag_line) return None diff --git a/hed/schema/schema_io/wiki_constants.py b/hed/schema/schema_io/wiki_constants.py index 2f7020654..7d1bf31ad 100644 --- a/hed/schema/schema_io/wiki_constants.py +++ b/hed/schema/schema_io/wiki_constants.py @@ -14,7 +14,7 @@ EPILOGUE_SECTION_ELEMENT = "'''Epilogue'''" wiki_section_headers = { - HedSectionKey.AllTags: START_HED_STRING, + HedSectionKey.Tags: START_HED_STRING, HedSectionKey.UnitClasses: UNIT_CLASS_STRING, HedSectionKey.Units: None, HedSectionKey.UnitModifiers: UNIT_MODIFIER_STRING, diff --git a/hed/schema/schema_io/xml2schema.py b/hed/schema/schema_io/xml2schema.py index 6db98eb68..437e11fff 100644 --- a/hed/schema/schema_io/xml2schema.py +++ b/hed/schema/schema_io/xml2schema.py @@ -113,7 +113,7 @@ def _populate_tag_dictionaries(self): A dictionary of dictionaries that has been populated with dictionaries associated with tag attributes. 
""" - self._schema._initialize_attributes(HedSectionKey.AllTags) + self._schema._initialize_attributes(HedSectionKey.Tags) tag_elements = self._get_elements_by_name("node") loading_from_chain = "" loading_from_chain_short = "" @@ -125,7 +125,7 @@ def _populate_tag_dictionaries(self): loading_from_chain = "" else: tag = tag.replace(loading_from_chain_short, loading_from_chain) - tag_entry = self._parse_node(tag_element, HedSectionKey.AllTags, tag) + tag_entry = self._parse_node(tag_element, HedSectionKey.Tags, tag) rooted_entry = schema_validation_util.find_rooted_entry(tag_entry, self._schema, self._loading_merged) if rooted_entry: @@ -133,9 +133,9 @@ def _populate_tag_dictionaries(self): loading_from_chain_short = tag_entry.short_tag_name tag = tag.replace(loading_from_chain_short, loading_from_chain) - tag_entry = self._parse_node(tag_element, HedSectionKey.AllTags, tag) + tag_entry = self._parse_node(tag_element, HedSectionKey.Tags, tag) - self._add_to_dict(tag_entry, HedSectionKey.AllTags) + self._add_to_dict(tag_entry, HedSectionKey.Tags) def _populate_unit_class_dictionaries(self): """Populates a dictionary of dictionaries associated with all the unit classes, unit class units, and unit diff --git a/hed/schema/schema_io/xml_constants.py b/hed/schema/schema_io/xml_constants.py index 3dbd7e647..7c2f6071d 100644 --- a/hed/schema/schema_io/xml_constants.py +++ b/hed/schema/schema_io/xml_constants.py @@ -36,7 +36,7 @@ SECTION_NAMES = { - HedSectionKey.AllTags: SCHEMA_ELEMENT, + HedSectionKey.Tags: SCHEMA_ELEMENT, HedSectionKey.UnitClasses: UNIT_CLASS_SECTION_ELEMENT, HedSectionKey.UnitModifiers: UNIT_MODIFIER_SECTION_ELEMENT, HedSectionKey.ValueClasses: SCHEMA_VALUE_CLASSES_SECTION_ELEMENT, @@ -46,7 +46,7 @@ ELEMENT_NAMES = { - HedSectionKey.AllTags: TAG_DEF_ELEMENT, + HedSectionKey.Tags: TAG_DEF_ELEMENT, HedSectionKey.UnitClasses: UNIT_CLASS_DEF_ELEMENT, HedSectionKey.Units: UNIT_CLASS_UNIT_ELEMENT, HedSectionKey.UnitModifiers: UNIT_MODIFIER_DEF_ELEMENT, @@ -57,7 
+57,7 @@ ATTRIBUTE_PROPERTY_ELEMENTS = { - HedSectionKey.AllTags: ATTRIBUTE_ELEMENT, + HedSectionKey.Tags: ATTRIBUTE_ELEMENT, HedSectionKey.UnitClasses: ATTRIBUTE_ELEMENT, HedSectionKey.Units: ATTRIBUTE_ELEMENT, HedSectionKey.UnitModifiers: ATTRIBUTE_ELEMENT, diff --git a/hed/schema/schema_validation_util.py b/hed/schema/schema_validation_util.py index aaf7cccea..7805a3977 100644 --- a/hed/schema/schema_validation_util.py +++ b/hed/schema/schema_validation_util.py @@ -140,7 +140,7 @@ def find_rooted_entry(tag_entry, schema, loading_merged): f'Found rooted tag \'{tag_entry.short_tag_name}\' as a root node in a merged schema.', schema.filename) - rooted_entry = schema.all_tags.get(rooted_tag) + rooted_entry = schema.tags.get(rooted_tag) if not rooted_entry or rooted_entry.has_attribute(constants.HedKey.InLibrary): raise HedFileError(HedExceptions.ROOTED_TAG_DOES_NOT_EXIST, f"Rooted tag '{tag_entry.short_tag_name}' not found in paired standard schema", diff --git a/hed/tools/util/schema_util.py b/hed/tools/util/schema_util.py index fb30d41a1..f14954d4f 100644 --- a/hed/tools/util/schema_util.py +++ b/hed/tools/util/schema_util.py @@ -11,7 +11,7 @@ def flatten_schema(hed_schema, skip_non_tag=False): """ children, parents, descriptions = [], [], [] for section in hed_schema._sections.values(): - if skip_non_tag and section.section_key != HedSectionKey.AllTags: + if skip_non_tag and section.section_key != HedSectionKey.Tags: continue for entry in section.all_entries: if entry.has_attribute(HedKey.TakesValue): diff --git a/tests/schema/test_hed_schema.py b/tests/schema/test_hed_schema.py index 4c30e1c52..f1b992511 100644 --- a/tests/schema/test_hed_schema.py +++ b/tests/schema/test_hed_schema.py @@ -137,7 +137,7 @@ def test_has_duplicate_tags(self): self.assertFalse(self.hed_schema_3g._has_duplicate_tags) def test_short_tag_mapping(self): - self.assertEqual(len(self.hed_schema_3g.all_tags.keys()), 1110) + self.assertEqual(len(self.hed_schema_3g.tags.keys()), 1110) def 
test_schema_compliance(self): warnings = self.hed_schema_group.check_compliance(True) diff --git a/tests/schema/test_hed_schema_io.py b/tests/schema/test_hed_schema_io.py index 2ef6987ed..db255fb04 100644 --- a/tests/schema/test_hed_schema_io.py +++ b/tests/schema/test_hed_schema_io.py @@ -235,10 +235,10 @@ def test_saving_bad_sort(self): self.assertEqual(loaded_schema, reloaded_schema) def _base_added_class_tests(self, schema): - tag_entry = schema.all_tags["Modulator"] + tag_entry = schema.tags["Modulator"] self.assertEqual(tag_entry.attributes["suggestedTag"], "Event") - tag_entry = schema.all_tags["Sleep-modulator"] + tag_entry = schema.tags["Sleep-modulator"] self.assertEqual(tag_entry.attributes["relatedTag"], "Sensory-event") unit_class_entry = schema.unit_classes["weightUnits"] diff --git a/tests/schema/test_schema_attribute_validators.py b/tests/schema/test_schema_attribute_validators.py index 67a25efb1..e3753c03e 100644 --- a/tests/schema/test_schema_attribute_validators.py +++ b/tests/schema/test_schema_attribute_validators.py @@ -11,28 +11,28 @@ def setUpClass(cls): cls.hed_schema = schema.load_schema_version("8.1.0") def test_util_placeholder(self): - tag_entry = self.hed_schema.all_tags["Event"] + tag_entry = self.hed_schema.tags["Event"] attribute_name = "unitClass" self.assertTrue(schema_attribute_validators.tag_is_placeholder_check(self.hed_schema, tag_entry, attribute_name)) attribute_name = "unitClass" - tag_entry = self.hed_schema.all_tags["Age/#"] + tag_entry = self.hed_schema.tags["Age/#"] self.assertFalse(schema_attribute_validators.tag_is_placeholder_check(self.hed_schema, tag_entry, attribute_name)) def test_util_suggested(self): - tag_entry = self.hed_schema.all_tags["Event/Sensory-event"] + tag_entry = self.hed_schema.tags["Event/Sensory-event"] attribute_name = "suggestedTag" self.assertFalse(schema_attribute_validators.tag_exists_check(self.hed_schema, tag_entry, attribute_name)) - tag_entry = self.hed_schema.all_tags["Property"] + 
tag_entry = self.hed_schema.tags["Property"] self.assertFalse(schema_attribute_validators.tag_exists_check(self.hed_schema, tag_entry, attribute_name)) tag_entry = copy.deepcopy(tag_entry) tag_entry.attributes["suggestedTag"] = "InvalidSuggestedTag" self.assertTrue(schema_attribute_validators.tag_exists_check(self.hed_schema, tag_entry, attribute_name)) def test_util_rooted(self): - tag_entry = self.hed_schema.all_tags["Event"] + tag_entry = self.hed_schema.tags["Event"] attribute_name = "rooted" self.assertFalse(schema_attribute_validators.tag_exists_base_schema_check(self.hed_schema, tag_entry, attribute_name)) - tag_entry = self.hed_schema.all_tags["Property"] + tag_entry = self.hed_schema.tags["Property"] self.assertFalse(schema_attribute_validators.tag_exists_base_schema_check(self.hed_schema, tag_entry, attribute_name)) tag_entry = copy.deepcopy(tag_entry) tag_entry.attributes["rooted"] = "Event" diff --git a/tests/schema/test_schema_compare.py b/tests/schema/test_schema_compare.py new file mode 100644 index 000000000..b20c65ba6 --- /dev/null +++ b/tests/schema/test_schema_compare.py @@ -0,0 +1,105 @@ +import unittest +import os +import io + +from hed.schema import HedKey, HedSectionKey, from_string +from hed.schema.schema_compare import compare_schemas, find_matching_tags, \ + pretty_print_diff_all, pretty_print_missing_all, compare_differences + + +class TestSchemaComparison(unittest.TestCase): + + library_schema_start = """HED library="testcomparison" version="1.1.0" withStandard="8.2.0" unmerged="true" + +'''Prologue''' + +!# start schema + +""" + + library_schema_end = """ +!# end schema + +!# end hed + """ + + def _get_test_schema(self, node_lines): + library_schema_string = self.library_schema_start + "\n".join(node_lines) + self.library_schema_end + test_schema = from_string(library_schema_string, ".mediawiki") + + return test_schema + + def load_schema1(self): + test_nodes = ["'''TestNode''' [This is a simple test node]\n", + " *TestNode2", + " 
*TestNode3", + " *TestNode4" + ] + return self._get_test_schema(test_nodes) + + def load_schema2(self): + test_nodes = ["'''TestNode''' [This is a simple test node]\n", + " *TestNode2", + " **TestNode3", + " *TestNode5" + ] + + return self._get_test_schema(test_nodes) + + def test_find_matching_tags(self): + # create entries for schema1 + schema1 = self.load_schema1() + schema2 = self.load_schema2() + + result = find_matching_tags(schema1, schema2) + # Check if the result is correct + self.assertEqual(len(result[HedSectionKey.Tags]), 3) + self.assertIn("TestNode", result[HedSectionKey.Tags]) + self.assertIn("TestNode2", result[HedSectionKey.Tags]) + self.assertIn("TestNode3", result[HedSectionKey.Tags]) + self.assertNotIn("TestNode4", result[HedSectionKey.Tags]) + self.assertNotIn("TestNode5", result[HedSectionKey.Tags]) + + match_string = find_matching_tags(schema1, schema2, return_string=True) + self.assertIsInstance(match_string, str) + print(match_string) + + def test_compare_schemas(self): + schema1 = self.load_schema1() + schema2 = self.load_schema2() + + matches, not_in_schema1, not_in_schema2, unequal_entries = compare_schemas(schema1, schema2) + + # Check if the result is correct + self.assertEqual(len(matches[HedSectionKey.Tags]), 2) # Two matches should be found + self.assertIn("TestNode", matches[HedSectionKey.Tags]) + self.assertIn("TestNode2", matches[HedSectionKey.Tags]) + self.assertNotIn("TestNode3", matches[HedSectionKey.Tags]) + + self.assertEqual(len(not_in_schema2[HedSectionKey.Tags]), 1) # One tag not in schema2 + self.assertIn("TestNode4", not_in_schema2[HedSectionKey.Tags]) # "TestNode4" is not in schema2 + + self.assertEqual(len(not_in_schema1[HedSectionKey.Tags]), 1) # One tag not in schema1 + self.assertIn("TestNode5", not_in_schema1[HedSectionKey.Tags]) # "TestNode5" is not in schema1 + + self.assertEqual(len(unequal_entries[HedSectionKey.Tags]), 1) # One unequal entry should be found + self.assertIn("TestNode3", 
unequal_entries[HedSectionKey.Tags]) + + def test_compare_differences(self): + schema1 = self.load_schema1() + schema2 = self.load_schema2() + + not_in_schema1, not_in_schema2, unequal_entries = compare_differences(schema1, schema2) + + self.assertEqual(len(not_in_schema2[HedSectionKey.Tags]), 1) # One tag not in schema2 + self.assertIn("TestNode4", not_in_schema2[HedSectionKey.Tags]) # "TestNode4" is not in schema2 + + self.assertEqual(len(not_in_schema1[HedSectionKey.Tags]), 1) # One tag not in schema1 + self.assertIn("TestNode5", not_in_schema1[HedSectionKey.Tags]) # "TestNode5" is not in schema1 + + self.assertEqual(len(unequal_entries[HedSectionKey.Tags]), 1) # One unequal entry should be found + self.assertIn("TestNode3", unequal_entries[HedSectionKey.Tags]) + + diff_string = compare_differences(schema1, schema2, return_string=True) + self.assertIsInstance(diff_string, str) + print(diff_string) diff --git a/tests/schema/test_schema_compliance.py b/tests/schema/test_schema_compliance.py index 467d34f7a..49f6afe02 100644 --- a/tests/schema/test_schema_compliance.py +++ b/tests/schema/test_schema_compliance.py @@ -1,8 +1,6 @@ import unittest import os - - from hed import schema