diff --git a/hed/models/string_util.py b/hed/models/string_util.py
new file mode 100644
index 000000000..9f6023223
--- /dev/null
+++ b/hed/models/string_util.py
@@ -0,0 +1,59 @@
+from hed.models.hed_string import HedString
+
+
+def split_base_tags(hed_string, base_tags, remove_group=False):
+    """ Split a HedString into the tags matching base_tags and everything else.
+
+    Args:
+        hed_string (HedString): The string to split.  Matches are removed from it in place.
+        base_tags (list of str): The base tags to look for (lower-cased before searching).
+            Only the base tag itself is matched, NOT all the terms above it.
+        remove_group (bool, optional): If True, pull out the group containing a match
+            rather than just the matching tag.  Defaults to False.
+
+    Returns:
+        tuple: Two HedString objects:
+            - hed_string itself, after the matches have been removed from it.
+            - A new HedString built from the matched tags (or their groups).
+    """
+    lowered = [base_tag.lower() for base_tag in base_tags]
+    # include_groups=2 yields the (tag, group) pairs that are unpacked below.
+    matches = hed_string.find_tags(lowered, recursive=True, include_groups=2 if remove_group else 0)
+    if remove_group:
+        # When the containing group is a HedString (rather than a nested group),
+        # the tag itself is taken instead of the group.
+        matches = [tag if isinstance(group, HedString) else group for tag, group in matches]
+
+    if matches:
+        hed_string.remove(matches)
+
+    return hed_string, HedString("", hed_string._schema, _contents=matches)
+
+
+def split_def_tags(hed_string, def_names, remove_group=False):
+    """ Split a HedString into the Def tags matching def_names and everything else.
+
+    This does NOT handle def-expand tags currently.
+
+    Args:
+        hed_string (HedString): The string to split.  Matches are removed from it in place.
+        def_names (list of str): The def names to search for.  Can optionally include a value.
+        remove_group (bool, optional): If True, pull out the group containing a match
+            rather than just the matching tag.  Defaults to False.
+
+    Returns:
+        tuple: Two HedString objects:
+            - hed_string itself, after the matches have been removed from it.
+            - A new HedString built from the matched tags (or their groups).
+    """
+    # Each name is searched for as the lower-cased wildcard "def/<name>".
+    patterns = [f"def/{name}".lower() for name in def_names]
+    groups_mode = 2 if remove_group else 0
+    matches = hed_string.find_wildcard_tags(patterns, recursive=True, include_groups=groups_mode)
+    if remove_group:
+        # When the containing group is a HedString, take the tag rather than the group.
+        matches = [tag if isinstance(group, HedString) else group for tag, group in matches]
+
+    if matches:
+        hed_string.remove(matches)
+    return hed_string, HedString("", hed_string._schema, _contents=matches)
diff --git a/hed/schema/hed_schema.py b/hed/schema/hed_schema.py
index c0f1d2589..7b71b2bb1 100644
--- a/hed/schema/hed_schema.py
+++ b/hed/schema/hed_schema.py
@@ -86,13 +86,13 @@ def merged(self):
return not self.header_attributes.get(constants.UNMERGED_ATTRIBUTE, "")
@property
- def all_tags(self):
+ def tags(self):
""" Return the tag schema section.
Returns:
HedSchemaTagSection: The tag section.
"""
- return self._sections[HedSectionKey.AllTags]
+ return self._sections[HedSectionKey.Tags]
@property
def unit_classes(self):
@@ -354,7 +354,7 @@ def check_compliance(self, check_for_warnings=True, name=None, error_handler=Non
from hed.schema import schema_compliance
return schema_compliance.check_compliance(self, check_for_warnings, name, error_handler)
- def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
+ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags):
""" Return tag entries with the given attribute.
Parameters:
@@ -370,7 +370,7 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
return self._sections[key_class].get_entries_with_attribute(attribute, return_name_only=True,
schema_namespace=self._namespace)
- def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""):
+ def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""):
""" Return the schema entry for this tag, if one exists.
Parameters:
@@ -378,12 +378,12 @@ def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=
This will not handle extensions or similar.
If this is a tag, it can have a schema namespace, but it's not required
key_class (HedSectionKey or str): The type of entry to return.
- schema_namespace (str): Only used on AllTags. If incorrect, will return None.
+ schema_namespace (str): Only used on Tags. If incorrect, will return None.
Returns:
HedSchemaEntry: The schema entry for the given tag.
"""
- if key_class == HedSectionKey.AllTags:
+ if key_class == HedSectionKey.Tags:
if schema_namespace != self._namespace:
return None
if name.startswith(self._namespace):
@@ -415,7 +415,7 @@ def find_tag_entry(self, tag, schema_namespace=""):
# ===============================================
# Private utility functions for getting/finding tags
# ===============================================
- def _get_tag_entry(self, name, key_class=HedSectionKey.AllTags):
+ def _get_tag_entry(self, name, key_class=HedSectionKey.Tags):
""" Return the schema entry for this tag, if one exists.
Parameters:
@@ -524,7 +524,7 @@ def _validate_remaining_terms(self, tag, working_tag, prefix_tag_adj, current_sl
tag,
index_in_tag=word_start_index,
index_in_tag_end=word_start_index + len(name),
- expected_parent_tag=self.all_tags[name].name)
+ expected_parent_tag=self.tags[name].name)
raise self._TagIdentifyError(error)
word_start_index += len(name) + 1
@@ -533,7 +533,7 @@ def _validate_remaining_terms(self, tag, working_tag, prefix_tag_adj, current_sl
# ===============================================
def finalize_dictionaries(self):
""" Call to finish loading. """
- self._has_duplicate_tags = bool(self.all_tags.duplicate_names)
+ self._has_duplicate_tags = bool(self.tags.duplicate_names)
self._update_all_entries()
def _update_all_entries(self):
@@ -568,13 +568,13 @@ def get_desc_iter(self):
if tag_entry.description:
yield tag_entry.name, tag_entry.description
- def get_tag_description(self, tag_name, key_class=HedSectionKey.AllTags):
+ def get_tag_description(self, tag_name, key_class=HedSectionKey.Tags):
""" Return the description associated with the tag.
Parameters:
tag_name (str): A hed tag name(or unit/unit modifier etc) with proper capitalization.
key_class (str): A string indicating type of description (e.g. All tags, Units, Unit modifier).
- The default is HedSectionKey.AllTags.
+ The default is HedSectionKey.Tags.
Returns:
str: A description of the specified tag.
@@ -595,7 +595,7 @@ def get_all_schema_tags(self, return_last_term=False):
"""
final_list = []
- for lower_tag, tag_entry in self.all_tags.items():
+ for lower_tag, tag_entry in self.tags.items():
if return_last_term:
final_list.append(tag_entry.name.split('/')[-1])
else:
@@ -636,7 +636,7 @@ def get_tag_attribute_names(self):
and not tag_entry.has_attribute(HedKey.UnitModifierProperty)
and not tag_entry.has_attribute(HedKey.ValueClassProperty)}
- def get_all_tag_attributes(self, tag_name, key_class=HedSectionKey.AllTags):
+ def get_all_tag_attributes(self, tag_name, key_class=HedSectionKey.Tags):
""" Gather all attributes for a given tag name.
Parameters:
@@ -670,7 +670,7 @@ def _create_empty_sections():
dictionaries[HedSectionKey.Units] = HedSchemaSection(HedSectionKey.Units)
dictionaries[HedSectionKey.UnitClasses] = HedSchemaUnitClassSection(HedSectionKey.UnitClasses)
dictionaries[HedSectionKey.ValueClasses] = HedSchemaSection(HedSectionKey.ValueClasses)
- dictionaries[HedSectionKey.AllTags] = HedSchemaTagSection(HedSectionKey.AllTags, case_sensitive=False)
+ dictionaries[HedSectionKey.Tags] = HedSchemaTagSection(HedSectionKey.Tags, case_sensitive=False)
return dictionaries
@@ -717,7 +717,7 @@ def _get_attributes_for_section(self, key_class):
dict or HedSchemaSection: A dict of all the attributes and this section.
"""
- if key_class == HedSectionKey.AllTags:
+ if key_class == HedSectionKey.Tags:
return self.get_tag_attribute_names()
elif key_class == HedSectionKey.Attributes:
prop_added_dict = {key: value for key, value in self._sections[HedSectionKey.Properties].items()}
diff --git a/hed/schema/hed_schema_base.py b/hed/schema/hed_schema_base.py
index b0e29ebcc..6651077e0 100644
--- a/hed/schema/hed_schema_base.py
+++ b/hed/schema/hed_schema_base.py
@@ -55,7 +55,7 @@ def valid_prefixes(self):
raise NotImplemented("This function must be implemented in the baseclass")
@abstractmethod
- def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
+ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags):
""" Return tag entries with the given attribute.
Parameters:
@@ -72,7 +72,7 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
# todo: maybe tweak this API so you don't have to pass in library namespace?
@abstractmethod
- def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""):
+ def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""):
""" Return the schema entry for this tag, if one exists.
Parameters:
@@ -80,7 +80,7 @@ def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=
This will not handle extensions or similar.
If this is a tag, it can have a schema namespace, but it's not required
key_class (HedSectionKey or str): The type of entry to return.
- schema_namespace (str): Only used on AllTags. If incorrect, will return None.
+ schema_namespace (str): Only used on Tags. If incorrect, will return None.
Returns:
HedSchemaEntry: The schema entry for the given tag.
diff --git a/hed/schema/hed_schema_constants.py b/hed/schema/hed_schema_constants.py
index 03a783fd4..e74f1b70f 100644
--- a/hed/schema/hed_schema_constants.py
+++ b/hed/schema/hed_schema_constants.py
@@ -4,7 +4,7 @@ class HedSectionKey(Enum):
""" Kegs designating specific sections in a HedSchema object.
"""
# overarching category listing all tags
- AllTags = 'tags'
+ Tags = 'tags'
# Overarching category listing all unit classes
UnitClasses = 'unitClasses'
# Overarching category listing all units(not divided by type)
diff --git a/hed/schema/hed_schema_group.py b/hed/schema/hed_schema_group.py
index 96187b73f..ae0ac2b81 100644
--- a/hed/schema/hed_schema_group.py
+++ b/hed/schema/hed_schema_group.py
@@ -101,7 +101,7 @@ def check_compliance(self, check_for_warnings=True, name=None, error_handler=Non
issues_list += schema.check_compliance(check_for_warnings, name, error_handler)
return issues_list
- def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
+ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.Tags):
""" Return tag entries with the given attribute.
Parameters:
@@ -114,12 +114,12 @@ def get_tags_with_attribute(self, attribute, key_class=HedSectionKey.AllTags):
Notes:
- The result is cached so will be fast after first call.
"""
- all_tags = set()
+ tags = set()
for schema in self._schemas.values():
- all_tags.update(schema.get_tags_with_attribute(attribute, key_class))
- return list(all_tags)
+ tags.update(schema.get_tags_with_attribute(attribute, key_class))
+ return list(tags)
- def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=""):
+ def get_tag_entry(self, name, key_class=HedSectionKey.Tags, schema_namespace=""):
""" Return the schema entry for this tag, if one exists.
Parameters:
@@ -127,7 +127,7 @@ def get_tag_entry(self, name, key_class=HedSectionKey.AllTags, schema_namespace=
This will not handle extensions or similar.
If this is a tag, it can have a schema namespace, but it's not required
key_class (HedSectionKey or str): The type of entry to return.
- schema_namespace (str): Only used on AllTags. If incorrect, will return None.
+ schema_namespace (str): Only used on Tags. If incorrect, will return None.
Returns:
HedSchemaEntry: The schema entry for the given tag.
diff --git a/hed/schema/hed_schema_section.py b/hed/schema/hed_schema_section.py
index 2da2ff54b..9b6ba657e 100644
--- a/hed/schema/hed_schema_section.py
+++ b/hed/schema/hed_schema_section.py
@@ -9,7 +9,7 @@
HedSectionKey.Units: UnitEntry,
HedSectionKey.UnitClasses: UnitClassEntry,
HedSectionKey.ValueClasses: HedSchemaEntry,
- HedSectionKey.AllTags: HedTagEntry,
+ HedSectionKey.Tags: HedTagEntry,
}
diff --git a/hed/schema/schema_attribute_validators.py b/hed/schema/schema_attribute_validators.py
index 2fa23d1db..47dc7410b 100644
--- a/hed/schema/schema_attribute_validators.py
+++ b/hed/schema/schema_attribute_validators.py
@@ -50,7 +50,7 @@ def tag_exists_check(hed_schema, tag_entry, attribute_name):
possible_tags = tag_entry.attributes.get(attribute_name, "")
split_tags = possible_tags.split(",")
for org_tag in split_tags:
- if org_tag and org_tag not in hed_schema.all_tags:
+ if org_tag and org_tag not in hed_schema.tags:
issues += ErrorHandler.format_error(ValidationErrors.NO_VALID_TAG_FOUND,
org_tag,
index_in_tag=0,
@@ -72,7 +72,7 @@ def tag_exists_base_schema_check(hed_schema, tag_entry, attribute_name):
"""
issues = []
rooted_tag = tag_entry.attributes.get(attribute_name, "")
- if rooted_tag and rooted_tag not in hed_schema.all_tags:
+ if rooted_tag and rooted_tag not in hed_schema.tags:
issues += ErrorHandler.format_error(ValidationErrors.NO_VALID_TAG_FOUND,
rooted_tag,
index_in_tag=0,
diff --git a/hed/schema/schema_compare.py b/hed/schema/schema_compare.py
new file mode 100644
index 000000000..49d5c72b6
--- /dev/null
+++ b/hed/schema/schema_compare.py
@@ -0,0 +1,202 @@
+from hed.schema.hed_schema import HedSchema, HedKey
+from hed.schema.hed_schema_constants import HedSectionKey
+
+
+def find_matching_tags(schema1, schema2, return_string=False):
+    """
+    Find the entries two library schemas have in common, whether or not they are identical.
+
+    Parameters:
+        schema1 (HedSchema): The first schema to be compared.
+        schema2 (HedSchema): The second schema to be compared.
+        return_string (bool): If True, return a formatted string instead of a dict.
+
+    Returns:
+        dict or str: Per section, the entries found in both schemas (equal or not), as pairs.
+    """
+    matches, _, _, unequal_entries = compare_schemas(schema1, schema2)
+    # Entries that share a key but differ still count as matching nodes, so fold them in.
+    for section_key in matches:
+        matches[section_key].update(unequal_entries[section_key])
+
+    if not return_string:
+        return matches
+    return "\n".join(pretty_print_diff_all(found, prompt="Found matching node ") for found in matches.values())
+
+
+def compare_differences(schema1, schema2, return_string=False, attribute_filter=None):
+    """
+    Report where two schemas differ: entries missing on either side and unequal entries.
+
+    Parameters:
+        schema1 (HedSchema): The first schema to be compared.
+        schema2 (HedSchema): The second schema to be compared.
+        return_string (bool): If True, return a formatted string instead of a tuple.
+        attribute_filter (str, optional): The attribute to filter entries by.
+                                          Entries without this attribute are skipped.
+                                          The most common use would be HedKey.InLibrary.
+                                          If it evaluates to False, no filtering is performed.
+    Returns:
+        tuple or str: A tuple containing three dictionaries:
+            - not_in_schema1(dict): Entries present in schema2 but not in schema1.
+            - not_in_schema2(dict): Entries present in schema1 but not in schema2.
+            - unequal_entries(dict): Entries that differ between the two schemas.
+            - or a formatted string of the differences
+    """
+    _, missing1, missing2, differing = compare_schemas(schema1, schema2, attribute_filter=attribute_filter)
+    if not return_string:
+        return missing1, missing2, differing
+    # Three newline-separated parts, in order: unequal entries, missing from 1, missing from 2.
+    unequal_text = "\n".join(pretty_print_diff_all(found) for found in differing.values())
+    missing1_text = "\n".join(pretty_print_missing_all(found, "Schema1") for found in missing1.values())
+    missing2_text = "\n".join(pretty_print_missing_all(found, "Schema2") for found in missing2.values())
+    return f"{unequal_text}\n{missing1_text}\n{missing2_text}"
+
+
+def compare_schemas(schema1, schema2, attribute_filter=HedKey.InLibrary, sections=(HedSectionKey.Tags,)):
+    """
+    Compare two schemas one section at a time.
+
+    For each compared section, entries are sorted into four buckets: equal in both schemas,
+    missing from schema1, missing from schema2, and present in both but unequal.
+
+    Entries are keyed by short_tag_name in the Tags section and by name in any other section,
+    so two entries are paired up when those keys are the same.
+
+    Parameters:
+        schema1 (HedSchema): The first schema to be compared.
+        schema2 (HedSchema): The second schema to be compared.
+        attribute_filter (str, optional): The attribute to filter entries by.
+                                          Entries without this attribute are skipped.
+                                          If it evaluates to False, no filtering is performed.
+        sections(list): The sections to compare.  By default, just the tags section.
+                        If it is empty, no section is compared.
+
+    Returns:
+        tuple: A tuple containing four dictionaries, each keyed by section key:
+            - matches(dict): Entries present in both schemas and are equal.
+            - not_in_schema1(dict): Entries present in schema2 but not in schema1.
+            - not_in_schema2(dict): Entries present in schema1 but not in schema2.
+            - unequal_entries(dict): Entries present in both schemas but are not equal.
+        In matches and unequal_entries each value is an (entry1, entry2) pair.
+    """
+    def _keyed_entries(section, key_attribute):
+        # Map key -> entry for the entries of one section that pass attribute_filter.
+        # A later entry with the same key replaces an earlier one.
+        return {getattr(entry, key_attribute): entry
+                for entry in section.all_entries
+                if not attribute_filter or entry.has_attribute(attribute_filter)}
+
+    matches, not_in_schema1, not_in_schema2, unequal_entries = {}, {}, {}, {}
+
+    # Walk the enum itself, so sections are handled in HedSectionKey order.
+    for section_key in HedSectionKey:
+        if not sections or section_key not in sections:
+            continue
+
+        # Fetch both sections before reading any entries.
+        section1, section2 = schema1[section_key], schema2[section_key]
+        if section_key == HedSectionKey.Tags:
+            key_attribute = 'short_tag_name'
+        else:
+            key_attribute = 'name'
+        entries1 = _keyed_entries(section1, key_attribute)
+        entries2 = _keyed_entries(section2, key_attribute)
+
+        # Keys that only one of the two schemas has.
+        not_in_schema2[section_key] = {key: entry for key, entry in entries1.items() if key not in entries2}
+        not_in_schema1[section_key] = {key: entry for key, entry in entries2.items() if key not in entries1}
+
+        # Keys both schemas have, split by whether the two entries compare equal.
+        shared = {key: (entry, entries2[key]) for key, entry in entries1.items() if key in entries2}
+        unequal_entries[section_key] = {key: pair for key, pair in shared.items() if pair[0] != pair[1]}
+        matches[section_key] = {key: pair for key, pair in shared.items() if pair[0] == pair[1]}
+
+    return matches, not_in_schema1, not_in_schema2, unequal_entries
+
+
+def pretty_print_diff_entry(entry1, entry2):
+    """
+    Returns the differences between two HedSchemaEntry objects as a list of strings.
+
+    Parameters:
+        entry1 (HedSchemaEntry): The first entry.
+        entry2 (HedSchemaEntry): The second entry.
+
+    Returns:
+        diff_lines(list): One tab-indented line per difference, in the order name, description,
+                          then attributes sorted by attribute name.  Empty if nothing differs.
+    """
+    output = []
+    if entry1.name != entry2.name:
+        output.append(f"\tName differs: '{entry1.name}' vs '{entry2.name}'")
+
+    if entry1.description != entry2.description:
+        output.append(f"\tDescription differs: '{entry1.description}' vs '{entry2.description}'")
+
+    # Sort the attribute names: iterating the bare set would make the line order vary between runs.
+    for attr in sorted(set(entry1.attributes).union(entry2.attributes)):
+        value1, value2 = entry1.attributes.get(attr), entry2.attributes.get(attr)
+        # An attribute missing on one side is reported with the value None for that side.
+        if value1 != value2:
+            output.append(f"\tAttribute '{attr}' differs: '{value1}' vs '{value2}'")
+    return output
+
+
+def pretty_print_entry(entry):
+    """ Returns the contents of a HedSchemaEntry object as a list of strings.
+
+    Parameters:
+        entry (HedSchemaEntry): The HedSchemaEntry object to be displayed.
+
+    Returns:
+        list: Tab-indented lines: the name, the description (unless it is None),
+              then one line per attribute in the order entry.attributes yields them.
+    """
+    name_line = [f"\tName: {entry.name}"]
+    # Only a description of None is left out; an empty string still gets a line.
+    if entry.description is None:
+        description_line = []
+    else:
+        description_line = [f"\tDescription: {entry.description}"]
+
+    attribute_lines = [f"\tAttribute: {attr_key} - Value: {attr_value}"
+                       for attr_key, attr_value in entry.attributes.items()]
+
+    return name_line + description_line + attribute_lines
+
+
+def pretty_print_diff_all(entries, prompt="Differences for "):
+    """
+    Formats the differences between pairs of HedSchemaEntry objects.
+
+    Parameters:
+        entries (dict): A dictionary where each key maps to a pair of HedSchemaEntry objects.
+        prompt(str): The text placed before the quoted key on each entry's header line.
+    Returns:
+        diff_string(str): A header line per key followed by that pair's difference lines,
+                          all joined with newlines.  Empty if entries is empty.
+    """
+    lines = []
+    for key, pair in entries.items():
+        first, second = pair
+        lines.extend([f"{prompt}'{key}':", *pretty_print_diff_entry(first, second)])
+    return "\n".join(lines)
+
+
+def pretty_print_missing_all(entries, schema_name):
+    """
+    Formats the entries that are missing from schema_name.
+
+    Parameters:
+        entries (dict): A dictionary where each key maps to a single HedSchemaEntry object.
+        schema_name(str): The name of the schema these entries are missing from.
+    Returns:
+        missing_string(str): A header line per key, then that entry's lines, joined by newlines.
+    """
+    output = []
+    for key, entry in entries.items():
+        # The header used to end in "':", leaving an unmatched quote after the schema name.
+        output.append(f"'{key}' not in {schema_name}:")
+        output += pretty_print_entry(entry)
+    return "\n".join(output)
\ No newline at end of file
diff --git a/hed/schema/schema_compliance.py b/hed/schema/schema_compliance.py
index 20db73376..59770e6b2 100644
--- a/hed/schema/schema_compliance.py
+++ b/hed/schema/schema_compliance.py
@@ -114,3 +114,5 @@ def check_invalid_chars(self):
for tag_name, desc in self.hed_schema.get_desc_iter():
issues_list += validate_schema_description(tag_name, desc)
return issues_list
+
+
diff --git a/hed/schema/schema_io/schema2base.py b/hed/schema/schema_io/schema2base.py
index 7ed8aceaf..280562fd9 100644
--- a/hed/schema/schema_io/schema2base.py
+++ b/hed/schema/schema_io/schema2base.py
@@ -41,7 +41,7 @@ def process_schema(self, hed_schema, save_merged=False):
self._output_header(hed_schema.get_save_header_attributes(self._save_merged), hed_schema.prologue)
- self._output_tags(hed_schema.all_tags)
+ self._output_tags(hed_schema.tags)
self._output_units(hed_schema.unit_classes)
self._output_section(hed_schema, HedSectionKey.UnitModifiers)
self._output_section(hed_schema, HedSectionKey.ValueClasses)
@@ -69,13 +69,13 @@ def _write_tag_entry(self, tag_entry, parent=None, level=0):
def _write_entry(self, entry, parent_node, include_props=True):
raise NotImplementedError("This needs to be defined in the subclass")
- def _output_tags(self, all_tags):
- schema_node = self._start_section(HedSectionKey.AllTags)
+ def _output_tags(self, tags):
+ schema_node = self._start_section(HedSectionKey.Tags)
# This assumes .all_entries is sorted in a reasonable way for output.
level_adj = 0
all_nodes = {} # List of all nodes we've written out.
- for tag_entry in all_tags.all_entries:
+ for tag_entry in tags.all_entries:
if self._should_skip(tag_entry):
continue
tag = tag_entry.name
diff --git a/hed/schema/schema_io/schema2xml.py b/hed/schema/schema_io/schema2xml.py
index 974480347..c25e38074 100644
--- a/hed/schema/schema_io/schema2xml.py
+++ b/hed/schema/schema_io/schema2xml.py
@@ -52,7 +52,7 @@ def _write_tag_entry(self, tag_entry, parent_node=None, level=0):
SubElement
The added node
"""
- key_class = HedSectionKey.AllTags
+ key_class = HedSectionKey.Tags
tag_element = xml_constants.ELEMENT_NAMES[key_class]
tag_description = tag_entry.description
tag_attributes = tag_entry.attributes
diff --git a/hed/schema/schema_io/wiki2schema.py b/hed/schema/schema_io/wiki2schema.py
index a66dc7b2e..1b91da9d1 100644
--- a/hed/schema/schema_io/wiki2schema.py
+++ b/hed/schema/schema_io/wiki2schema.py
@@ -290,7 +290,7 @@ def _read_schema(self, lines):
lines: [(int, str)]
Lines for this section
"""
- self._schema._initialize_attributes(HedSectionKey.AllTags)
+ self._schema._initialize_attributes(HedSectionKey.Tags)
parent_tags = []
level_adj = 0
for line_number, line in lines:
@@ -322,7 +322,7 @@ def _read_schema(self, lines):
self._add_fatal_error(line_number, line, e.message, e.code)
continue
- tag_entry = self._add_to_dict(line_number, line, tag_entry, HedSectionKey.AllTags)
+ tag_entry = self._add_to_dict(line_number, line, tag_entry, HedSectionKey.Tags)
parent_tags.append(tag_entry.short_tag_name)
@@ -594,7 +594,7 @@ def _add_tag_line(self, parent_tags, line_number, tag_line):
long_tag_name = "/".join(parent_tags) + "/" + tag_name
else:
long_tag_name = tag_name
- return self._create_entry(line_number, tag_line, HedSectionKey.AllTags, long_tag_name)
+ return self._create_entry(line_number, tag_line, HedSectionKey.Tags, long_tag_name)
self._add_fatal_error(line_number, tag_line)
return None
diff --git a/hed/schema/schema_io/wiki_constants.py b/hed/schema/schema_io/wiki_constants.py
index 2f7020654..7d1bf31ad 100644
--- a/hed/schema/schema_io/wiki_constants.py
+++ b/hed/schema/schema_io/wiki_constants.py
@@ -14,7 +14,7 @@
EPILOGUE_SECTION_ELEMENT = "'''Epilogue'''"
wiki_section_headers = {
- HedSectionKey.AllTags: START_HED_STRING,
+ HedSectionKey.Tags: START_HED_STRING,
HedSectionKey.UnitClasses: UNIT_CLASS_STRING,
HedSectionKey.Units: None,
HedSectionKey.UnitModifiers: UNIT_MODIFIER_STRING,
diff --git a/hed/schema/schema_io/xml2schema.py b/hed/schema/schema_io/xml2schema.py
index 6db98eb68..437e11fff 100644
--- a/hed/schema/schema_io/xml2schema.py
+++ b/hed/schema/schema_io/xml2schema.py
@@ -113,7 +113,7 @@ def _populate_tag_dictionaries(self):
A dictionary of dictionaries that has been populated with dictionaries associated with tag attributes.
"""
- self._schema._initialize_attributes(HedSectionKey.AllTags)
+ self._schema._initialize_attributes(HedSectionKey.Tags)
tag_elements = self._get_elements_by_name("node")
loading_from_chain = ""
loading_from_chain_short = ""
@@ -125,7 +125,7 @@ def _populate_tag_dictionaries(self):
loading_from_chain = ""
else:
tag = tag.replace(loading_from_chain_short, loading_from_chain)
- tag_entry = self._parse_node(tag_element, HedSectionKey.AllTags, tag)
+ tag_entry = self._parse_node(tag_element, HedSectionKey.Tags, tag)
rooted_entry = schema_validation_util.find_rooted_entry(tag_entry, self._schema, self._loading_merged)
if rooted_entry:
@@ -133,9 +133,9 @@ def _populate_tag_dictionaries(self):
loading_from_chain_short = tag_entry.short_tag_name
tag = tag.replace(loading_from_chain_short, loading_from_chain)
- tag_entry = self._parse_node(tag_element, HedSectionKey.AllTags, tag)
+ tag_entry = self._parse_node(tag_element, HedSectionKey.Tags, tag)
- self._add_to_dict(tag_entry, HedSectionKey.AllTags)
+ self._add_to_dict(tag_entry, HedSectionKey.Tags)
def _populate_unit_class_dictionaries(self):
"""Populates a dictionary of dictionaries associated with all the unit classes, unit class units, and unit
diff --git a/hed/schema/schema_io/xml_constants.py b/hed/schema/schema_io/xml_constants.py
index 3dbd7e647..7c2f6071d 100644
--- a/hed/schema/schema_io/xml_constants.py
+++ b/hed/schema/schema_io/xml_constants.py
@@ -36,7 +36,7 @@
SECTION_NAMES = {
- HedSectionKey.AllTags: SCHEMA_ELEMENT,
+ HedSectionKey.Tags: SCHEMA_ELEMENT,
HedSectionKey.UnitClasses: UNIT_CLASS_SECTION_ELEMENT,
HedSectionKey.UnitModifiers: UNIT_MODIFIER_SECTION_ELEMENT,
HedSectionKey.ValueClasses: SCHEMA_VALUE_CLASSES_SECTION_ELEMENT,
@@ -46,7 +46,7 @@
ELEMENT_NAMES = {
- HedSectionKey.AllTags: TAG_DEF_ELEMENT,
+ HedSectionKey.Tags: TAG_DEF_ELEMENT,
HedSectionKey.UnitClasses: UNIT_CLASS_DEF_ELEMENT,
HedSectionKey.Units: UNIT_CLASS_UNIT_ELEMENT,
HedSectionKey.UnitModifiers: UNIT_MODIFIER_DEF_ELEMENT,
@@ -57,7 +57,7 @@
ATTRIBUTE_PROPERTY_ELEMENTS = {
- HedSectionKey.AllTags: ATTRIBUTE_ELEMENT,
+ HedSectionKey.Tags: ATTRIBUTE_ELEMENT,
HedSectionKey.UnitClasses: ATTRIBUTE_ELEMENT,
HedSectionKey.Units: ATTRIBUTE_ELEMENT,
HedSectionKey.UnitModifiers: ATTRIBUTE_ELEMENT,
diff --git a/hed/schema/schema_validation_util.py b/hed/schema/schema_validation_util.py
index aaf7cccea..7805a3977 100644
--- a/hed/schema/schema_validation_util.py
+++ b/hed/schema/schema_validation_util.py
@@ -140,7 +140,7 @@ def find_rooted_entry(tag_entry, schema, loading_merged):
f'Found rooted tag \'{tag_entry.short_tag_name}\' as a root node in a merged schema.',
schema.filename)
- rooted_entry = schema.all_tags.get(rooted_tag)
+ rooted_entry = schema.tags.get(rooted_tag)
if not rooted_entry or rooted_entry.has_attribute(constants.HedKey.InLibrary):
raise HedFileError(HedExceptions.ROOTED_TAG_DOES_NOT_EXIST,
f"Rooted tag '{tag_entry.short_tag_name}' not found in paired standard schema",
diff --git a/hed/tools/util/schema_util.py b/hed/tools/util/schema_util.py
index fb30d41a1..f14954d4f 100644
--- a/hed/tools/util/schema_util.py
+++ b/hed/tools/util/schema_util.py
@@ -11,7 +11,7 @@ def flatten_schema(hed_schema, skip_non_tag=False):
"""
children, parents, descriptions = [], [], []
for section in hed_schema._sections.values():
- if skip_non_tag and section.section_key != HedSectionKey.AllTags:
+ if skip_non_tag and section.section_key != HedSectionKey.Tags:
continue
for entry in section.all_entries:
if entry.has_attribute(HedKey.TakesValue):
diff --git a/tests/models/test_string_util.py b/tests/models/test_string_util.py
new file mode 100644
index 000000000..6819d7bca
--- /dev/null
+++ b/tests/models/test_string_util.py
@@ -0,0 +1,135 @@
+import unittest
+from hed import HedString, load_schema_version
+from hed.models.string_util import split_base_tags, split_def_tags
+import copy
+
+
+class TestHedStringSplit(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.schema = load_schema_version()
+
+ def check_split_base_tags(self, hed_string, base_tags, expected_string, expected_string2):
+ # Test case 1: remove_group=False
+ hed_string_copy = copy.deepcopy(hed_string)
+ remaining_hed, found_hed = split_base_tags(hed_string_copy, base_tags, remove_group=False)
+
+ self.assertIsInstance(remaining_hed, HedString)
+ self.assertIsInstance(found_hed, HedString)
+ self.assertEqual(str(remaining_hed), expected_string)
+
+ self.assertTrue(all(tag in [str(t) for t in found_hed.get_all_tags()] for tag in base_tags))
+ self.assertTrue(all(tag not in [str(t) for t in remaining_hed.get_all_tags()] for tag in base_tags))
+
+ # Test case 2: remove_group=True
+ hed_string_copy = copy.deepcopy(hed_string)
+ remaining_hed, found_hed = split_base_tags(hed_string_copy, base_tags, remove_group=True)
+
+ self.assertIsInstance(remaining_hed, HedString)
+ self.assertIsInstance(found_hed, HedString)
+ self.assertEqual(str(remaining_hed), expected_string2)
+
+ self.assertTrue(all(tag in [str(t) for t in found_hed.get_all_tags()] for tag in base_tags))
+ self.assertTrue(all(tag not in [str(t) for t in remaining_hed.get_all_tags()] for tag in base_tags))
+
+ def test_case_1(self):
+ hed_string = HedString('Memorize,Action,Area', self.schema)
+ base_tags = ['Area', 'Action']
+ expected_string = 'Memorize'
+ expected_string2 = 'Memorize'
+ self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2)
+
+ def test_case_2(self):
+ hed_string = HedString('Area,LightBlue,Handedness', self.schema)
+ base_tags = ['Area', 'LightBlue']
+ expected_string = 'Handedness'
+ expected_string2 = 'Handedness'
+ self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2)
+
+ def test_case_3(self):
+ hed_string = HedString('(Wink,Communicate),Face,HotPink', self.schema)
+ base_tags = ['Wink', 'Face']
+ expected_string = '(Communicate),HotPink'
+ expected_string2 = "HotPink"
+ self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2)
+
+ def test_case_4(self):
+ hed_string = HedString('(Area,(LightBlue,Handedness,(Wink,Communicate))),Face,HotPink', self.schema)
+ base_tags = ['Area', 'LightBlue']
+ expected_string = '((Handedness,(Wink,Communicate))),Face,HotPink'
+ expected_string2 = 'Face,HotPink'
+ self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2)
+
+ def test_case_5(self):
+ hed_string = HedString('(Memorize,(Action,(Area,LightBlue),Handedness),Wink)', self.schema)
+ base_tags = ['Area', 'LightBlue']
+ expected_string = '(Memorize,(Action,Handedness),Wink)'
+ expected_string2 = '(Memorize,(Action,Handedness),Wink)'
+ self.check_split_base_tags(hed_string, base_tags, expected_string, expected_string2)
+
+class TestHedStringSplitDef(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.schema = load_schema_version()
+
+ def check_split_def_tags(self, hed_string, def_names, expected_string, expected_string2):
+ # Test case 1: remove_group=False
+ hed_string_copy = copy.deepcopy(hed_string)
+ remaining_hed, found_hed = split_def_tags(hed_string_copy, def_names, remove_group=False)
+
+ self.assertIsInstance(remaining_hed, HedString)
+ self.assertIsInstance(found_hed, HedString)
+ self.assertEqual(str(remaining_hed), expected_string)
+
+ self.assertTrue(all(tag.short_base_tag == "Def" for tag in found_hed.get_all_tags()))
+ self.assertTrue(all(tag.short_base_tag != "Def" for tag in remaining_hed.get_all_tags()))
+
+ # Test case 2: remove_group=True
+ hed_string_copy = copy.deepcopy(hed_string)
+ remaining_hed, found_hed = split_def_tags(hed_string_copy, def_names, remove_group=True)
+
+ self.assertIsInstance(remaining_hed, HedString)
+ self.assertIsInstance(found_hed, HedString)
+ self.assertEqual(str(remaining_hed), expected_string2)
+
+ #self.assertTrue(all(tag.short_base_tag == "Def" for tag in found_hed.get_all_tags()))
+ self.assertTrue(all(tag.short_base_tag != "Def" for tag in remaining_hed.get_all_tags()))
+
+ def test_case_1(self):
+ hed_string = HedString('Memorize,Action,def/CustomTag1', self.schema)
+ def_names = ['CustomTag1']
+ expected_string = 'Memorize,Action'
+ expected_string2 = 'Memorize,Action'
+ self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2)
+
+ def test_case_2(self):
+ hed_string = HedString('def/CustomTag1,LightBlue,def/CustomTag2/123', self.schema)
+ def_names = ['CustomTag1', 'CustomTag2']
+ expected_string = 'LightBlue'
+ expected_string2 = 'LightBlue'
+ self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2)
+
+ def test_case_3(self):
+ hed_string = HedString('(def/CustomTag1,Communicate),Face,def/CustomTag3/abc', self.schema)
+ def_names = ['CustomTag1', 'CustomTag3']
+ expected_string = '(Communicate),Face'
+ expected_string2 = 'Face'
+ self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2)
+
+ def test_case_4(self):
+ hed_string = HedString('(def/CustomTag1,(LightBlue,def/CustomTag2/123,(Wink,Communicate))),Face,def/CustomTag3/abc', self.schema)
+ def_names = ['CustomTag1', 'CustomTag2', 'CustomTag3']
+ expected_string = '((LightBlue,(Wink,Communicate))),Face'
+ expected_string2 = 'Face'
+ self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2)
+
+ def test_case_5(self):
+ hed_string = HedString('(Memorize,(Action,(def/CustomTag1,LightBlue),def/CustomTag2/123),Wink)', self.schema)
+ def_names = ['CustomTag1', 'CustomTag2']
+ expected_string = '(Memorize,(Action,(LightBlue)),Wink)'
+ expected_string2 = '(Memorize,Wink)'
+ self.check_split_def_tags(hed_string, def_names, expected_string, expected_string2)
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/schema/test_hed_schema.py b/tests/schema/test_hed_schema.py
index 4c30e1c52..f1b992511 100644
--- a/tests/schema/test_hed_schema.py
+++ b/tests/schema/test_hed_schema.py
@@ -137,7 +137,7 @@ def test_has_duplicate_tags(self):
self.assertFalse(self.hed_schema_3g._has_duplicate_tags)
def test_short_tag_mapping(self):
- self.assertEqual(len(self.hed_schema_3g.all_tags.keys()), 1110)
+ self.assertEqual(len(self.hed_schema_3g.tags.keys()), 1110)
def test_schema_compliance(self):
warnings = self.hed_schema_group.check_compliance(True)
diff --git a/tests/schema/test_hed_schema_io.py b/tests/schema/test_hed_schema_io.py
index 2ef6987ed..db255fb04 100644
--- a/tests/schema/test_hed_schema_io.py
+++ b/tests/schema/test_hed_schema_io.py
@@ -235,10 +235,10 @@ def test_saving_bad_sort(self):
self.assertEqual(loaded_schema, reloaded_schema)
def _base_added_class_tests(self, schema):
- tag_entry = schema.all_tags["Modulator"]
+ tag_entry = schema.tags["Modulator"]
self.assertEqual(tag_entry.attributes["suggestedTag"], "Event")
- tag_entry = schema.all_tags["Sleep-modulator"]
+ tag_entry = schema.tags["Sleep-modulator"]
self.assertEqual(tag_entry.attributes["relatedTag"], "Sensory-event")
unit_class_entry = schema.unit_classes["weightUnits"]
diff --git a/tests/schema/test_schema_attribute_validators.py b/tests/schema/test_schema_attribute_validators.py
index 67a25efb1..e3753c03e 100644
--- a/tests/schema/test_schema_attribute_validators.py
+++ b/tests/schema/test_schema_attribute_validators.py
@@ -11,28 +11,28 @@ def setUpClass(cls):
cls.hed_schema = schema.load_schema_version("8.1.0")
def test_util_placeholder(self):
- tag_entry = self.hed_schema.all_tags["Event"]
+ tag_entry = self.hed_schema.tags["Event"]
attribute_name = "unitClass"
self.assertTrue(schema_attribute_validators.tag_is_placeholder_check(self.hed_schema, tag_entry, attribute_name))
attribute_name = "unitClass"
- tag_entry = self.hed_schema.all_tags["Age/#"]
+ tag_entry = self.hed_schema.tags["Age/#"]
self.assertFalse(schema_attribute_validators.tag_is_placeholder_check(self.hed_schema, tag_entry, attribute_name))
def test_util_suggested(self):
- tag_entry = self.hed_schema.all_tags["Event/Sensory-event"]
+ tag_entry = self.hed_schema.tags["Event/Sensory-event"]
attribute_name = "suggestedTag"
self.assertFalse(schema_attribute_validators.tag_exists_check(self.hed_schema, tag_entry, attribute_name))
- tag_entry = self.hed_schema.all_tags["Property"]
+ tag_entry = self.hed_schema.tags["Property"]
self.assertFalse(schema_attribute_validators.tag_exists_check(self.hed_schema, tag_entry, attribute_name))
tag_entry = copy.deepcopy(tag_entry)
tag_entry.attributes["suggestedTag"] = "InvalidSuggestedTag"
self.assertTrue(schema_attribute_validators.tag_exists_check(self.hed_schema, tag_entry, attribute_name))
def test_util_rooted(self):
- tag_entry = self.hed_schema.all_tags["Event"]
+ tag_entry = self.hed_schema.tags["Event"]
attribute_name = "rooted"
self.assertFalse(schema_attribute_validators.tag_exists_base_schema_check(self.hed_schema, tag_entry, attribute_name))
- tag_entry = self.hed_schema.all_tags["Property"]
+ tag_entry = self.hed_schema.tags["Property"]
self.assertFalse(schema_attribute_validators.tag_exists_base_schema_check(self.hed_schema, tag_entry, attribute_name))
tag_entry = copy.deepcopy(tag_entry)
tag_entry.attributes["rooted"] = "Event"
diff --git a/tests/schema/test_schema_compare.py b/tests/schema/test_schema_compare.py
new file mode 100644
index 000000000..b20c65ba6
--- /dev/null
+++ b/tests/schema/test_schema_compare.py
@@ -0,0 +1,105 @@
+import unittest
+import os
+import io
+
+from hed.schema import HedKey, HedSectionKey, from_string
+from hed.schema.schema_compare import compare_schemas, find_matching_tags, \
+ pretty_print_diff_all, pretty_print_missing_all, compare_differences
+
+
+class TestSchemaComparison(unittest.TestCase):
+
+ library_schema_start = """HED library="testcomparison" version="1.1.0" withStandard="8.2.0" unmerged="true"
+
+'''Prologue'''
+
+!# start schema
+
+"""
+
+ library_schema_end = """
+!# end schema
+
+!# end hed
+ """
+
+ def _get_test_schema(self, node_lines):
+ library_schema_string = self.library_schema_start + "\n".join(node_lines) + self.library_schema_end
+ test_schema = from_string(library_schema_string, ".mediawiki")
+
+ return test_schema
+
+ def load_schema1(self):
+ test_nodes = ["'''TestNode''' [This is a simple test node]\n",
+ " *TestNode2",
+ " *TestNode3",
+ " *TestNode4"
+ ]
+ return self._get_test_schema(test_nodes)
+
+ def load_schema2(self):
+ test_nodes = ["'''TestNode''' [This is a simple test node]\n",
+ " *TestNode2",
+ " **TestNode3",
+ " *TestNode5"
+ ]
+
+ return self._get_test_schema(test_nodes)
+
+ def test_find_matching_tags(self):
+ # Load the two schemas to compare
+ schema1 = self.load_schema1()
+ schema2 = self.load_schema2()
+
+ result = find_matching_tags(schema1, schema2)
+ # Check if the result is correct
+ self.assertEqual(len(result[HedSectionKey.Tags]), 3)
+ self.assertIn("TestNode", result[HedSectionKey.Tags])
+ self.assertIn("TestNode2", result[HedSectionKey.Tags])
+ self.assertIn("TestNode3", result[HedSectionKey.Tags])
+ self.assertNotIn("TestNode4", result[HedSectionKey.Tags])
+ self.assertNotIn("TestNode5", result[HedSectionKey.Tags])
+
+ match_string = find_matching_tags(schema1, schema2, return_string=True)
+ self.assertIsInstance(match_string, str)
+ print(match_string)
+
+ def test_compare_schemas(self):
+ schema1 = self.load_schema1()
+ schema2 = self.load_schema2()
+
+ matches, not_in_schema1, not_in_schema2, unequal_entries = compare_schemas(schema1, schema2)
+
+ # Check if the result is correct
+ self.assertEqual(len(matches[HedSectionKey.Tags]), 2) # Two matches should be found
+ self.assertIn("TestNode", matches[HedSectionKey.Tags])
+ self.assertIn("TestNode2", matches[HedSectionKey.Tags])
+ self.assertNotIn("TestNode3", matches[HedSectionKey.Tags])
+
+ self.assertEqual(len(not_in_schema2[HedSectionKey.Tags]), 1) # One tag not in schema2
+ self.assertIn("TestNode4", not_in_schema2[HedSectionKey.Tags]) # "TestNode4" is not in schema2
+
+ self.assertEqual(len(not_in_schema1[HedSectionKey.Tags]), 1) # One tag not in schema1
+ self.assertIn("TestNode5", not_in_schema1[HedSectionKey.Tags]) # "TestNode5" is not in schema1
+
+ self.assertEqual(len(unequal_entries[HedSectionKey.Tags]), 1) # One unequal entry should be found
+ self.assertIn("TestNode3", unequal_entries[HedSectionKey.Tags])
+
+ def test_compare_differences(self):
+ schema1 = self.load_schema1()
+ schema2 = self.load_schema2()
+
+ not_in_schema1, not_in_schema2, unequal_entries = compare_differences(schema1, schema2)
+
+ self.assertEqual(len(not_in_schema2[HedSectionKey.Tags]), 1) # One tag not in schema2
+ self.assertIn("TestNode4", not_in_schema2[HedSectionKey.Tags]) # "TestNode4" is not in schema2
+
+ self.assertEqual(len(not_in_schema1[HedSectionKey.Tags]), 1) # One tag not in schema1
+ self.assertIn("TestNode5", not_in_schema1[HedSectionKey.Tags]) # "TestNode5" is not in schema1
+
+ self.assertEqual(len(unequal_entries[HedSectionKey.Tags]), 1) # One unequal entry should be found
+ self.assertIn("TestNode3", unequal_entries[HedSectionKey.Tags])
+
+ diff_string = compare_differences(schema1, schema2, return_string=True)
+ self.assertIsInstance(diff_string, str)
+ print(diff_string)
diff --git a/tests/schema/test_schema_compliance.py b/tests/schema/test_schema_compliance.py
index 467d34f7a..49f6afe02 100644
--- a/tests/schema/test_schema_compliance.py
+++ b/tests/schema/test_schema_compliance.py
@@ -1,8 +1,6 @@
import unittest
import os
-
-
from hed import schema