From 1ae46f30ab4d1dbe8882921179610a24f934722a Mon Sep 17 00:00:00 2001 From: IanCa Date: Fri, 21 Jul 2023 18:41:57 -0500 Subject: [PATCH] Tag Schema Entries now know their children Minor improvements to HedGroup.remove Minor doc string improvements schema.tags.root_tags now gets the root level tags --- hed/models/hed_group.py | 46 ++++++++++++++--------------- hed/models/hed_string.py | 2 +- hed/schema/hed_schema_entry.py | 4 +++ hed/schema/hed_schema_section.py | 4 ++- tests/schema/test_schema_compare.py | 4 +-- 5 files changed, 32 insertions(+), 28 deletions(-) diff --git a/hed/models/hed_group.py b/hed/models/hed_group.py index ba3fc287c..98e81df42 100644 --- a/hed/models/hed_group.py +++ b/hed/models/hed_group.py @@ -1,5 +1,6 @@ from hed.models.hed_tag import HedTag import copy +from typing import Iterable, Union class HedGroup: @@ -79,7 +80,7 @@ def replace(self, item_to_replace, new_contents): self._children[replace_index] = new_contents new_contents._parent = self - def remove(self, items_to_remove): + def remove(self, items_to_remove: Iterable[Union[HedTag, 'HedGroup']]): """ Remove any tags/groups in items_to_remove. Parameters: @@ -87,28 +88,28 @@ def remove(self, items_to_remove): Notes: - Any groups that become empty will also be pruned. + - If you pass a child and parent group, the child will also be removed from the parent. """ - all_groups = self.get_all_groups() - self._remove(items_to_remove, all_groups) - - def _remove(self, items_to_remove, all_groups): empty_groups = [] - for remove_child in items_to_remove: - for group in all_groups: - # only proceed if we have an EXACT match for this child - if any(remove_child is child for child in group._children): - if group._original_children is group._children: - group._original_children = group._children.copy() - - group._children = [child for child in group._children if child is not remove_child] - # If this was the last child, flag this group to be removed on a second pass - if not group._children and group is not self: - empty_groups.append(group) - break + # Filter out duplicates + items_to_remove = {id(item):item for item in items_to_remove}.values() + + for item in items_to_remove: + group = item._parent + if group._original_children is group._children: + group._original_children = group._children.copy() + + group._children.remove(item) + if not group._children and group is not self: + empty_groups.append(group) if empty_groups: self.remove(empty_groups) + # Do this last to avoid confusing typing + for item in items_to_remove: + item._parent = None + def __copy__(self): raise ValueError("Cannot make shallow copies of HedGroups") @@ -368,14 +369,15 @@ def find_tags(self, search_tags, recursive=False, include_groups=2): search_tags (container): A container of short_base_tags to locate recursive (bool): If true, also check subgroups. include_groups (0, 1 or 2): Specify return values. + If 0: return a list of the HedTags. + If 1: return a list of the HedGroups containing the HedTags. + If 2: return a list of tuples (HedTag, HedGroup) for the found tags. Returns: list: The contents of the list depends on the value of include_groups. Notes: - - If include_groups is 0, return a list of the HedTags. - - If include_groups is 1, return a list of the HedGroups containing the HedTags. - - If include_groups is 2, return a list of tuples (HedTag, HedGroup) for the found tags. + - This can only find identified tags. - By default, definition, def, def-expand, onset, and offset are identified, even without a schema. @@ -408,10 +410,6 @@ def find_wildcard_tags(self, search_tags, recursive=False, include_groups=2): Returns: list: The contents of the list depends on the value of include_groups. - - Notes: - - This can only find identified tags. - - By default, definition, def, def-expand, onset, and offset are identified, even without a schema. """ found_tags = [] if recursive: diff --git a/hed/models/hed_string.py b/hed/models/hed_string.py index 60c8a392a..328316868 100644 --- a/hed/models/hed_string.py +++ b/hed/models/hed_string.py @@ -103,7 +103,7 @@ def copy(self): """ Return a deep copy of this string. Returns: - HedGroup: The copied group. + HedString: The copied group. """ return_copy = copy.deepcopy(self) diff --git a/hed/schema/hed_schema_entry.py b/hed/schema/hed_schema_entry.py index 80a1c3421..408b8d984 100644 --- a/hed/schema/hed_schema_entry.py +++ b/hed/schema/hed_schema_entry.py @@ -210,6 +210,8 @@ def __init__(self, *args, **kwargs): self.tag_terms = tuple() # During setup, it's better to have attributes shadow inherited before getting its own copy later. self.inherited_attributes = self.attributes + # Descendent tags below this one + self.children = {} def has_attribute(self, attribute, return_value=False): """ Returns th existence or value of an attribute in this entry. @@ -344,6 +346,8 @@ def finalize_entry(self, schema): if parent_name: parent_tag = schema._get_tag_entry(parent_name) self._parent_tag = parent_tag + if self._parent_tag: + self._parent_tag.children[self.short_tag_name] = self self.takes_value_child_entry = schema._get_tag_entry(self.name + "/#") self.tag_terms = tuple(self.long_tag_name.lower().split("/")) diff --git a/hed/schema/hed_schema_section.py b/hed/schema/hed_schema_section.py index 9b6ba657e..d312a29ce 100644 --- a/hed/schema/hed_schema_section.py +++ b/hed/schema/hed_schema_section.py @@ -165,6 +165,7 @@ def __init__(self, *args, case_sensitive=False, **kwargs): # This dict contains all forms of all tags. The .all_names variable has ONLY the long forms. self.long_form_tags = {} self.inheritable_attributes = {} + self.root_tags = {} @staticmethod def _get_tag_forms(name): @@ -267,5 +268,6 @@ def _finalize_section(self, hed_schema): if extension_allowed_node: split_list[extension_allowed_node:] = sorted(split_list[extension_allowed_node:], key=lambda x: x[0].long_tag_name) self.all_entries = [subitem for tag_list in split_list for subitem in tag_list] - super()._finalize_section(hed_schema) + super()._finalize_section(hed_schema) + self.root_tags = {tag.short_tag_name:tag for tag in self.all_entries if not tag._parent_tag} diff --git a/tests/schema/test_schema_compare.py b/tests/schema/test_schema_compare.py index b20c65ba6..829c4bad5 100644 --- a/tests/schema/test_schema_compare.py +++ b/tests/schema/test_schema_compare.py @@ -62,7 +62,7 @@ def test_find_matching_tags(self): match_string = find_matching_tags(schema1, schema2, return_string=True) self.assertIsInstance(match_string, str) - print(match_string) + # print(match_string) def test_compare_schemas(self): schema1 = self.load_schema1() @@ -102,4 +102,4 @@ def test_compare_differences(self): diff_string = compare_differences(schema1, schema2, return_string=True) self.assertIsInstance(diff_string, str) - print(diff_string) + # print(diff_string)