diff --git a/ofrak_core/ofrak/model/resource_model.py b/ofrak_core/ofrak/model/resource_model.py index e36c98b33..502b38b30 100644 --- a/ofrak_core/ofrak/model/resource_model.py +++ b/ofrak_core/ofrak/model/resource_model.py @@ -693,6 +693,9 @@ def __init__( self.is_deleted = False self._diff: Optional[ResourceModelDiff] = None + def __hash__(self): + return self.id.__hash__() + @property def diff(self): if not self._diff: diff --git a/ofrak_core/ofrak/resource.py b/ofrak_core/ofrak/resource.py index 12dbf0e28..c9af13318 100644 --- a/ofrak_core/ofrak/resource.py +++ b/ofrak_core/ofrak/resource.py @@ -296,12 +296,12 @@ async def save(self): self._resource_view_context, ) - def _save( - self, - resources_to_delete: List[bytes], - patches_to_apply: List[DataPatch], - resources_to_update: List[MutableResourceModel], - ): + def _save(self) -> Tuple[List[bytes], List[DataPatch], List[MutableResourceModel]]: + + resources_to_delete: List[bytes] = [] + patches_to_apply: List[DataPatch] = [] + resources_to_update: List[MutableResourceModel] = [] + if self._resource.is_deleted: resources_to_delete.append(self._resource.id) elif self._resource.is_modified: @@ -316,6 +316,8 @@ def _save( resources_to_update.append(self._resource) modification_tracker.data_patches.clear() + return resources_to_delete, patches_to_apply, resources_to_update + async def _fetch(self, resource: MutableResourceModel): """ Update the local model with the latest version from the resource service. This will fail @@ -1543,11 +1545,11 @@ async def save_resources( resources_to_update: List[MutableResourceModel] = [] for resource in resources: - resource._save( - resources_to_delete, - patches_to_apply, - resources_to_update, - ) + _resources_to_delete, _patches_to_apply, _resources_to_update = resource._save() + + resources_to_delete.extend(_resources_to_delete) + patches_to_apply.extend(_patches_to_apply) + resources_to_update.extend(_resources_to_update) deleted_descendants = await resource_service.delete_resources(resources_to_delete) data_ids_to_delete = [ @@ -1555,7 +1557,9 @@ async def save_resources( ] await data_service.delete_models(data_ids_to_delete) patch_results = await data_service.apply_patches(patches_to_apply) - await dependency_handler.handle_post_patch_dependencies(patch_results) + resources_to_update.extend( + await dependency_handler.handle_post_patch_dependencies(patch_results) + ) diffs = [] updated_ids = [] for resource_m in resources_to_update: diff --git a/ofrak_core/ofrak/service/dependency_handler.py b/ofrak_core/ofrak/service/dependency_handler.py index 440fdcbd8..8ea8c7577 100644 --- a/ofrak_core/ofrak/service/dependency_handler.py +++ b/ofrak_core/ofrak/service/dependency_handler.py @@ -60,7 +60,11 @@ async def map_data_ids_to_resources( return resources_by_data_id - async def handle_post_patch_dependencies(self, patch_results: List[DataPatchesResult]): + async def handle_post_patch_dependencies( + self, patch_results: List[DataPatchesResult] + ) -> Set[MutableResourceModel]: + modified_resources = set() + # Create look up maps for resources and dependencies resources_by_data_id = await self.map_data_ids_to_resources( patch_result.data_id for patch_result in patch_results @@ -78,6 +82,7 @@ async def handle_post_patch_dependencies(self, patch_results: List[DataPatchesRe resource_m = resources_by_data_id[data_patch_result.data_id] data_m = models_by_data_id[data_patch_result.data_id] resource_m.add_attributes(Data(data_m.range.start, data_m.range.length())) + modified_resources.add(resource_m) unhandled_dependencies: Set[ResourceAttributeDependency] = set() # Figure out which components results must be invalidated based on data changes @@ -109,15 +114,17 @@ async def handle_post_patch_dependencies(self, patch_results: List[DataPatchesRe break for removed_data_dependency in removed_data_dependencies: resource_m.remove_dependency(removed_data_dependency) + modified_resources.add(resource_m) # Recursively invalidate component results based on other components that were invalidated handled_dependencies: Set[ResourceAttributeDependency] = set() await self._invalidate_dependencies( - handled_dependencies, - unhandled_dependencies, + handled_dependencies, unhandled_dependencies, modified_resources ) + return modified_resources + def create_component_dependencies( self, component_id: bytes, @@ -224,6 +231,7 @@ async def _invalidate_dependencies( self, handled_dependencies: Set[ResourceAttributeDependency], unhandled_dependencies: Set[ResourceAttributeDependency], + resources_modified: Set[MutableResourceModel], ): """ Invalidate the unhandled resource attribute dependencies. @@ -282,6 +290,7 @@ async def _invalidate_dependencies( if resource_m.get_component_id_by_attributes(dependency.attributes): resource_m.remove_component(dependency.component_id, dependency.attributes) self._component_context.mark_resource_modified(resource_m.id) + resources_modified.add(resource_m) # Find other dependencies to invalidate due to the invalidation of the attributes invalidated_dependencies = set() @@ -301,11 +310,13 @@ async def _invalidate_dependencies( for invalidated_dependency in invalidated_dependencies: resource_m.remove_dependency(invalidated_dependency) self._component_context.mark_resource_modified(resource_m.id) + resources_modified.add(resource_m) next_unhandled_dependencies.update(invalidated_dependencies) await self._invalidate_dependencies( handled_dependencies, next_unhandled_dependencies, + resources_modified, )