Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hed/errors/error_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,9 @@ def _add_context_to_errors(error_object, error_context_to_add):
@staticmethod
def _create_error_object(error_type, base_message, severity, **kwargs):
if severity == ErrorSeverity.ERROR:
error_prefix = "ERROR: "
error_prefix = f"{error_type}: "
else:
error_prefix = "WARNING: "
error_prefix = f"{error_type} (Warning): "
error_message = error_prefix + base_message
error_object = {'code': error_type,
'message': error_message,
Expand Down
71 changes: 58 additions & 13 deletions hed/schema/schema_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,19 @@
# This is still in design, means header attributes, epilogue, and prologue
MiscSection = "misc"

SectionEntryNames = {
HedSectionKey.Tags: "Tag",
HedSectionKey.Units: "Unit",
HedSectionKey.UnitClasses: "Unit Class",
HedSectionKey.ValueClasses: "Value Class",
HedSectionKey.UnitModifiers: "Unit Modifier",
HedSectionKey.Properties: "Property",
HedSectionKey.Attributes: "Attribute",

def find_matching_tags(schema1, schema2, output='default', sections=(HedSectionKey.Tags,)):
}


def find_matching_tags(schema1, schema2, output='raw', sections=(HedSectionKey.Tags,)):
"""
Compare the tags in two library schemas. This finds tags with the same term.

Expand All @@ -19,16 +30,25 @@ def find_matching_tags(schema1, schema2, output='default', sections=(HedSectionK
If None, checks all sections including header, prologue, and epilogue.

Returns:
dict or str: A dictionary containing matching entries in the Tags section of both schemas.
dict, json style dict, or str: A dictionary containing matching entries in the Tags section of both schemas.
"""
matches, _, _, unequal_entries = compare_schemas(schema1, schema2, sections=sections)

for section_key, section_dict in matches.items():
section_dict.update(unequal_entries[section_key])

if output == 'string':
return "\n".join([_pretty_print_diff_all(entries, prompt="Found matching node ")
for entries in matches.values()])
final_string = ""
if sections is None:
sections = HedSectionKey
for section_key in sections:
type_name = SectionEntryNames[section_key]
entries = matches[section_key]
if not entries:
continue
final_string += f"{type_name} differences:\n"
final_string += _pretty_print_diff_all(entries, type_name=type_name) + "\n"
return final_string
elif output == 'dict':
output_dict = {}
for section_name, section_entries in matches.items():
Expand Down Expand Up @@ -71,11 +91,25 @@ def compare_differences(schema1, schema2, output='raw', attribute_filter=None, s
_, not_in_1, not_in_2, unequal_entries = compare_schemas(schema1, schema2, attribute_filter=attribute_filter,
sections=sections)

if sections is None:
sections = HedSectionKey

if output == 'string':
str1 = "\n".join([_pretty_print_diff_all(entries) for entries in unequal_entries.values()]) + "\n"
str2 = "\n".join([_pretty_print_missing_all(entries, "Schema1") for entries in not_in_1.values()]) + "\n"
str3 = "\n".join([_pretty_print_missing_all(entries, "Schema2") for entries in not_in_2.values()])
return str1 + str2 + str3
final_string = ""
for section_key in sections:
val1, val2, val3 = unequal_entries[section_key], not_in_1[section_key], not_in_2[section_key]
type_name = SectionEntryNames[section_key]
if val1 or val2 or val3:
if final_string:
final_string += "\n\n"
final_string += f"{type_name} differences:\n"
if val1:
final_string += _pretty_print_diff_all(val1, type_name=type_name) + "\n"
if val2:
final_string += _pretty_print_missing_all(val2, "Schema1", type_name) + "\n"
if val3:
final_string += _pretty_print_missing_all(val3, "Schema2", type_name) + "\n"
return final_string
elif output == 'dict':
# todo: clean this part up
output_dict = {}
Expand Down Expand Up @@ -286,37 +320,48 @@ def _pretty_print_diff_entry(entry1, entry2):
return diff_lines


def _pretty_print_diff_all(entries, prompt="Differences for "):
def _pretty_print_diff_all(entries, type_name=""):
"""
Formats the differences between pairs of HedSchemaEntry objects.

Parameters:
entries (dict): A dictionary where each key maps to a pair of HedSchemaEntry objects.
prompt(str): The prompt for each entry
type_name(str): The type to identify this as, such as Tag
Returns:
diff_string(str): The differences found in the dict
"""
output = []
if not type_name.endswith(" "):
type_name += " "
if not entries:
return ""
for key, (entry1, entry2) in entries.items():
output.append(f"{prompt}'{key}':")
output.append(f"{type_name}'{key}':")
output += _pretty_print_diff_entry(entry1, entry2)
output.append("")

return "\n".join(output)


def _pretty_print_missing_all(entries, schema_name):
def _pretty_print_missing_all(entries, schema_name, type_name):
"""
Formats the missing entries from schema_name.

Parameters:
entries (dict): A dictionary where each key maps to a pair of HedSchemaEntry objects.
schema_name(str): The name these entries are missing from
type_name(str): The type to identify this as, such as Tag
Returns:
diff_string(str): The differences found in the dict
"""
output = []
if not entries:
return ""
if not type_name.endswith(" "):
type_name += " "
for key, entry in entries.items():
output.append(f"'{key}' not in '{schema_name}':")
output.append(f"{type_name}'{key}' not in '{schema_name}':")
output += _pretty_print_entry(entry)
output.append("")

return "\n".join(output)