From 05c3eb868f050a1ee98e7ec01c5e7a0408eb9915 Mon Sep 17 00:00:00 2001 From: Martin Varga Date: Tue, 17 Jun 2025 11:00:59 +0200 Subject: [PATCH] Refactor upload changes handling. Move relevant validations into upload schema. Use upload schema only in push init, while project file object is used for actual file manipulations. Remove version context and redundant dataclasses for upload schema. --- server/mergin/sync/commands.py | 4 +- server/mergin/sync/files.py | 196 +++++++++--- server/mergin/sync/models.py | 110 +++---- server/mergin/sync/public_api_controller.py | 299 +++++++++--------- server/mergin/sync/schemas.py | 2 +- server/mergin/sync/storages/disk.py | 8 +- server/mergin/tests/fixtures.py | 7 +- server/mergin/tests/test_db_hooks.py | 9 +- .../mergin/tests/test_project_controller.py | 19 +- server/mergin/tests/utils.py | 38 +-- 10 files changed, 377 insertions(+), 315 deletions(-) diff --git a/server/mergin/sync/commands.py b/server/mergin/sync/commands.py index 97e85981..4ec898cf 100644 --- a/server/mergin/sync/commands.py +++ b/server/mergin/sync/commands.py @@ -9,7 +9,6 @@ from datetime import datetime from flask import Flask, current_app -from .files import UploadChanges from ..app import db from .models import Project, ProjectVersion from .utils import split_project_path @@ -52,8 +51,7 @@ def create(name, namespace, username): # pylint: disable=W0612 p = Project(**project_params) p.updated = datetime.utcnow() db.session.add(p) - changes = UploadChanges(added=[], updated=[], removed=[]) - pv = ProjectVersion(p, 0, user.id, changes, "127.0.0.1") + pv = ProjectVersion(p, 0, user.id, [], "127.0.0.1") pv.project = p db.session.commit() os.makedirs(p.storage.project_dir, exist_ok=True) diff --git a/server/mergin/sync/files.py b/server/mergin/sync/files.py index 12b30afe..6073b0fb 100644 --- a/server/mergin/sync/files.py +++ b/server/mergin/sync/files.py @@ -3,14 +3,29 @@ # SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-MerginMaps-Commercial import datetime 
import os +import uuid from dataclasses import dataclass -from typing import Optional, List -from marshmallow import fields, EXCLUDE, pre_load, post_load, post_dump +from enum import Enum +from flask import current_app +from marshmallow import ValidationError, fields, EXCLUDE, post_dump, validates_schema from pathvalidate import sanitize_filename +from typing import Optional, List +from .utils import is_file_name_blacklisted, is_qgis, is_versioned_file from ..app import DateTimeWithZ, ma +class PushChangeType(Enum): + CREATE = "create" + UPDATE = "update" + DELETE = "delete" + UPDATE_DIFF = "update_diff" + + @classmethod + def values(cls): + return [member.value for member in cls.__members__.values()] + + def mergin_secure_filename(filename: str) -> str: """Generate secure filename for given file""" filename = os.path.normpath(filename) @@ -24,94 +39,181 @@ def mergin_secure_filename(filename: str) -> str: @dataclass class File: - """Base class for every file object""" + """Base class for every file object, either intended to upload or already existing in project""" path: str checksum: str size: int - location: str def is_valid_gpkg(self): """Check if diff file is valid""" return self.size != 0 +@dataclass +class ProjectDiffFile(File): + """Metadata for geodiff diff file (aka. 
changeset) associated with geopackage""" + + # location where file is actually stored + location: str + + @dataclass class ProjectFile(File): - """Project file metadata including metadata for diff file""" + """Project file metadata including metadata for diff file and location where it is stored""" # metadata for gpkg diff file - diff: Optional[File] + diff: Optional[ProjectDiffFile] # deprecated attribute kept for public API compatibility mtime: Optional[datetime.datetime] + # location where file is actually stored + location: str @dataclass -class UploadFile(File): - """File to be uploaded coming from client push process""" - - # determined by client - chunks: Optional[List[str]] - diff: Optional[File] - +class ProjectFileChange(ProjectFile): + """Metadata of changed file in project version. + + This item is saved into database into file_history. + """ + + change: PushChangeType + + +def files_changes_from_upload(changes: dict, version: int) -> List["ProjectFileChange"]: + """Create a list of version file changes from upload changes dictionary used by public API. + + It flattens changes dict and adds change type to each item. Also generates location for each file. 
+ """ + secure_filenames = [] + version_changes = [] + version = "v" + str(version) + for key in ("added", "updated", "removed"): + for item in changes.get(key, []): + location = os.path.join(version, mergin_secure_filename(item["path"])) + diff = None + + # make sure we have unique location for each file + if location in secure_filenames: + filename, file_extension = os.path.splitext(location) + location = filename + f".{str(uuid.uuid4())}" + file_extension + + secure_filenames.append(location) + + if key == "removed": + change = PushChangeType.DELETE + location = None + elif key == "added": + change = PushChangeType.CREATE + else: + change = PushChangeType.UPDATE + if item.get("diff"): + change = PushChangeType.UPDATE_DIFF + diff_location = os.path.join( + version, mergin_secure_filename(item["diff"]["path"]) + ) + if diff_location in secure_filenames: + filename, file_extension = os.path.splitext(diff_location) + diff_location = ( + filename + f".{str(uuid.uuid4())}" + file_extension + ) + + secure_filenames.append(diff_location) + diff = ProjectDiffFile( + path=item["diff"]["path"], + checksum=item["diff"]["checksum"], + size=item["diff"]["size"], + location=diff_location, + ) + + file_change = ProjectFileChange( + path=item["path"], + checksum=item["checksum"], + size=item["size"], + mtime=None, + change=change, + location=location, + diff=diff, + ) + version_changes.append(file_change) -@dataclass -class UploadChanges: - added: List[UploadFile] - updated: List[UploadFile] - removed: List[UploadFile] + return version_changes class FileSchema(ma.Schema): path = fields.String() size = fields.Integer() checksum = fields.String() - location = fields.String(load_default="", load_only=True) class Meta: unknown = EXCLUDE - @post_load - def create_obj(self, data, **kwargs): - return File(**data) - class UploadFileSchema(FileSchema): chunks = fields.List(fields.String(), load_default=[]) diff = fields.Nested(FileSchema(), many=False, load_default=None) - @pre_load - 
def pre_load(self, data, **kwargs): - # add future location based on context version - version = f"v{self.context.get('version')}" - if not data.get("location"): - data["location"] = os.path.join( - version, mergin_secure_filename(data["path"]) - ) - if data.get("diff") and not data.get("diff").get("location"): - data["diff"]["location"] = os.path.join( - version, mergin_secure_filename(data["diff"]["path"]) - ) - return data - - @post_load - def create_obj(self, data, **kwargs): - return UploadFile(**data) - class ChangesSchema(ma.Schema): """Schema for upload changes""" - added = fields.List(fields.Nested(UploadFileSchema()), load_default=[]) - updated = fields.List(fields.Nested(UploadFileSchema()), load_default=[]) - removed = fields.List(fields.Nested(UploadFileSchema()), load_default=[]) + added = fields.List( + fields.Nested(UploadFileSchema()), load_default=[], dump_default=[] + ) + updated = fields.List( + fields.Nested(UploadFileSchema()), load_default=[], dump_default=[] + ) + removed = fields.List( + fields.Nested(UploadFileSchema()), load_default=[], dump_default=[] + ) + is_blocking = fields.Method("_is_blocking") class Meta: unknown = EXCLUDE - @post_load - def create_obj(self, data, **kwargs): - return UploadChanges(**data) + def _is_blocking(self, obj) -> bool: + """Check if changes would be blocking.""" + # let's mark upload as non-blocking only if there are new non-spatial data added (e.g. 
photos) + return bool( + len(obj.get("updated", [])) + or len(obj.get("removed", [])) + or any( + is_qgis(f["path"]) or is_versioned_file(f["path"]) + for f in obj.get("added", []) + ) + ) + + @post_dump + def remove_blacklisted_files(self, data, **kwargs): + """Files which are blacklisted are not allowed to be uploaded and are simply ignored.""" + for key in ("added", "updated", "removed"): + data[key] = [ + f + for f in data[key] + if not is_file_name_blacklisted( + f["path"], current_app.config["BLACKLIST"] + ) + ] + return data + + @validates_schema + def validate(self, data, **kwargs): + """Basic consistency validations for upload metadata""" + changes_files = [ + f["path"] for f in data["added"] + data["updated"] + data["removed"] + ] + + if len(changes_files) == 0: + raise ValidationError("No changes") + + # changes' files must be unique + if len(set(changes_files)) != len(changes_files): + raise ValidationError("Not unique changes") + + # check if all .gpkg files are valid + for file in data["added"] + data["updated"]: + if is_versioned_file(file["path"]) and file["size"] == 0: + raise ValidationError("File is not valid") class ProjectFileSchema(FileSchema): diff --git a/server/mergin/sync/models.py b/server/mergin/sync/models.py index 5992726a..be57e4d4 100644 --- a/server/mergin/sync/models.py +++ b/server/mergin/sync/models.py @@ -23,9 +23,11 @@ from .files import ( File, - UploadChanges, + ProjectDiffFile, + ProjectFileChange, ChangesSchema, ProjectFile, + PushChangeType, ) from .interfaces import WorkspaceRole from .storages.disk import move_to_tmp @@ -38,17 +40,6 @@ project_access_granted = signal("project_access_granted") -class PushChangeType(Enum): - CREATE = "create" - UPDATE = "update" - DELETE = "delete" - UPDATE_DIFF = "update_diff" - - @classmethod - def values(cls): - return [member.value for member in cls.__members__.values()] - - class Project(db.Model): id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = 
db.Column(db.String, index=True) @@ -181,7 +172,7 @@ def files(self) -> List[ProjectFile]: checksum=row.checksum, location=row.location, mtime=row.mtime, - diff=File(**row.diff) if row.diff else None, + diff=ProjectDiffFile(**row.diff) if row.diff else None, ) for row in db.session.execute(query, params).fetchall() ] @@ -504,9 +495,9 @@ def path(self) -> str: return self.file.path @property - def diff_file(self) -> Optional[File]: + def diff_file(self) -> Optional[ProjectDiffFile]: if self.diff: - return File(**self.diff) + return ProjectDiffFile(**self.diff) @property def mtime(self) -> datetime: @@ -705,7 +696,7 @@ def __init__( project: Project, name: int, author_id: int, - changes: UploadChanges, + changes: List[ProjectFileChange], ip: str, user_agent: str = None, device_id: str = None, @@ -725,9 +716,7 @@ def __init__( ).all() } - changed_files_paths = [ - f.path for f in changes.updated + changes.removed + changes.added - ] + changed_files_paths = set(change.path for change in changes) existing_files_map = { f.path: f for f in ProjectFilePath.query.filter_by(project_id=self.project_id) @@ -735,46 +724,32 @@ def __init__( .all() } - for key in ( - ("added", PushChangeType.CREATE), - ("updated", PushChangeType.UPDATE), - ("removed", PushChangeType.DELETE), - ): - change_attr = key[0] - change_type = key[1] - - for upload_file in getattr(changes, change_attr): - is_diff_change = ( - change_type is PushChangeType.UPDATE - and upload_file.diff is not None - ) - - file = existing_files_map.get( - upload_file.path, ProjectFilePath(self.project_id, upload_file.path) - ) - fh = FileHistory( - file=file, - size=upload_file.size, - checksum=upload_file.checksum, - location=upload_file.location, - diff=( - asdict(upload_file.diff) - if (is_diff_change and upload_file.diff) - else null() - ), - change=( - PushChangeType.UPDATE_DIFF if is_diff_change else change_type - ), - ) - fh.version = self - fh.project_version_name = self.name - db.session.add(fh) - 
db.session.flush() + for item in changes: + # get existing DB file reference or create a new one (for added files) + db_file = existing_files_map.get( + item.path, ProjectFilePath(self.project_id, item.path) + ) + fh = FileHistory( + file=db_file, + size=item.size, + checksum=item.checksum, + location=item.location, + diff=( + asdict(item.diff) + if (item.change is PushChangeType.UPDATE_DIFF and item.diff) + else null() + ), + change=item.change, + ) + fh.version = self + fh.project_version_name = self.name + db.session.add(fh) + db.session.flush() - if change_type is PushChangeType.DELETE: - latest_files_map.pop(fh.path, None) - else: - latest_files_map[fh.path] = fh.id + if item.change is PushChangeType.DELETE: + latest_files_map.pop(fh.path, None) + else: + latest_files_map[fh.path] = fh.id # update cached values in project and push to transaction buffer so that self.files is up-to-date self.project.latest_project_files.file_history_ids = latest_files_map.values() @@ -909,7 +884,7 @@ def files(self) -> List[ProjectFile]: checksum=row.checksum, location=row.location, mtime=row.mtime, - diff=File(**row.diff) if row.diff else None, + diff=ProjectDiffFile(**row.diff) if row.diff else None, ) for row in result ] @@ -1029,12 +1004,13 @@ class Upload(db.Model): ), ) - def __init__(self, project: Project, changes: UploadChanges, user_id: int): + def __init__(self, project: Project, changes: dict, user_id: int): + upload_changes = ChangesSchema().dump(changes) self.id = str(uuid.uuid4()) self.project_id = project.id - self.changes = ChangesSchema().dump(changes) self.user_id = user_id - self.blocking = self.is_blocking(changes) + self.blocking = upload_changes.pop("is_blocking") + self.changes = upload_changes @property def upload_dir(self): @@ -1059,16 +1035,6 @@ def clear(self): db.session.delete(self) db.session.commit() - @staticmethod - def is_blocking(changes: UploadChanges) -> bool: - """Check if changes would be blocking.""" - # let's mark upload as non-blocking 
only if there are new non-spatial data added (e.g. photos) - return bool( - len(changes.updated) - or len(changes.removed) - or any(is_qgis(f.path) or is_versioned_file(f.path) for f in changes.added) - ) - def file_already_in_upload(self, path) -> bool: """Check if file is not already as new added file""" return any(f["path"] == path for f in self.changes["added"]) diff --git a/server/mergin/sync/public_api_controller.py b/server/mergin/sync/public_api_controller.py index 88c47bf6..e9fb4a76 100644 --- a/server/mergin/sync/public_api_controller.py +++ b/server/mergin/sync/public_api_controller.py @@ -10,8 +10,8 @@ from dataclasses import asdict from typing import Dict from urllib.parse import quote -import uuid from datetime import datetime +from marshmallow import ValidationError import gevent import psycopg2 @@ -50,11 +50,13 @@ ProjectRole, ) from .files import ( - UploadChanges, ChangesSchema, - UploadFileSchema, + ProjectDiffFile, + ProjectFile, + ProjectFileChange, ProjectFileSchema, - FileSchema, + files_changes_from_upload, + mergin_secure_filename, ) from .schemas import ( ProjectSchema, @@ -239,15 +241,24 @@ def add_project(namespace): # noqa: E501 .first_or_404() ) version_name = 1 - files = UploadFileSchema(context={"version": 1}, many=True).load( - FileSchema(exclude=("location",), many=True).dump(template.files) - ) - changes = UploadChanges(added=files, updated=[], removed=[]) + file_changes = [] + for file in template.files: + file_changes.append( + ProjectFileChange( + file.path, + file.checksum, + file.size, + diff=None, + mtime=None, + location=os.path.join("v1", mergin_secure_filename(file.path)), + change=PushChangeType.CREATE, + ) + ) else: template = None version_name = 0 - changes = UploadChanges(added=[], updated=[], removed=[]) + file_changes = [] try: p.storage.initialize(template_project=template) @@ -258,7 +269,7 @@ def add_project(namespace): # noqa: E501 p, version_name, current_user.id, - changes, + file_changes, get_ip(request), 
get_user_agent(request), get_device_id(request), @@ -758,13 +769,15 @@ def project_push(namespace, project_name): if not pv and version != 0: abort(400, "First push should be with v0") - if all(len(changes[key]) == 0 for key in changes.keys()): - abort(400, "No changes") - - upload_changes = ChangesSchema(context={"version": pv.name + 1}).load(changes) + try: + ChangesSchema().validate(changes) + upload_changes = ChangesSchema().dump(changes) + except ValidationError as err: + msg = err.messages[0] if type(err.messages) == list else "Invalid input data" + abort(400, msg) # reject upload early if there is another blocking upload already running - if Upload.is_blocking(upload_changes): + if upload_changes["is_blocking"]: pending_upload = Upload.query.filter_by( project_id=project.id, blocking=True ).first() @@ -772,77 +785,45 @@ def project_push(namespace, project_name): abort(400, "Another process is running. Please try later.") current_files = set(file.path for file in project.files) - for item in upload_changes.added: + for item in upload_changes["added"]: # check if same file is not already uploaded or in pending upload - if item.path in current_files: - abort(400, f"File {item.path} has been already uploaded") + item_path = item["path"] + if item_path in current_files: + abort(400, f"File {item_path} has been already uploaded") for upload in project.uploads.all(): if not upload.is_active(): continue - if upload.file_already_in_upload(item.path): - abort(400, f"File {item.path} is already in other upload") + if upload.file_already_in_upload(item_path): + abort(400, f"File {item_path} is already in other upload") - if not is_valid_path(item.path): + if not is_valid_path(item_path): abort( 400, - f"Unsupported file name detected: {item.path}. Please remove the invalid characters.", + f"Unsupported file name detected: {item_path}. 
Please remove the invalid characters.", ) - if not is_supported_extension(item.path): + if not is_supported_extension(item_path): abort( 400, - f"Unsupported file type detected: {item.path}. " + f"Unsupported file type detected: {item_path}. " f"Please remove the file or try compressing it into a ZIP file before uploading.", ) # check consistency of changes if not set( - file.path for file in upload_changes.updated + upload_changes.removed + file["path"] for file in upload_changes["updated"] + upload_changes["removed"] ).issubset(current_files): abort(400, "Update or remove changes contain files that are not in project") - # changes' files must be unique - changes_files = [ - f.path - for f in upload_changes.added + upload_changes.updated + upload_changes.removed - ] - if len(set(changes_files)) != len(changes_files): - abort(400, "Not unique changes") - - sanitized_files = [] - blacklisted_files = [] - for f in upload_changes.added + upload_changes.updated + upload_changes.removed: - # check if .gpkg file is valid - if is_versioned_file(f.path): - if not f.is_valid_gpkg(): - abort(400, f"File {f.path} is not valid") - if is_file_name_blacklisted(f.path, current_app.config["BLACKLIST"]): - blacklisted_files.append(f.path) - # all file need to be unique after sanitized - if f.location in sanitized_files: - filename, file_extension = os.path.splitext(f.location) - f.location = filename + f".{str(uuid.uuid4())}" + file_extension - sanitized_files.append(f.location) - if f.diff: - if f.diff.location in sanitized_files: - filename, file_extension = os.path.splitext(f.diff.location) - f.diff.location = filename + f".{str(uuid.uuid4())}" + file_extension - sanitized_files.append(f.diff.location) - - # remove blacklisted files from changes - for key in upload_changes.__dict__.keys(): - new_value = [ - f for f in getattr(upload_changes, key) if f.path not in blacklisted_files - ] - setattr(upload_changes, key, new_value) - # Check user data limit - updates = [f.path for f 
in upload_changes.updated] + updates = [f["path"] for f in upload_changes["updated"]] updated_files = list(filter(lambda i: i.path in updates, project.files)) additional_disk_usage = ( - sum(file.size for file in upload_changes.added + upload_changes.updated) + sum( + file["size"] for file in upload_changes["added"] + upload_changes["updated"] + ) - sum(file.size for file in updated_files) - - sum(file.size for file in upload_changes.removed) + - sum(file["size"] for file in upload_changes["removed"]) ) current_usage = ws.disk_usage() @@ -899,12 +880,13 @@ def project_push(namespace, project_name): next_version = version + 1 user_agent = get_user_agent(request) device_id = get_device_id(request) + file_changes = files_changes_from_upload(upload.changes, version=next_version) try: pv = ProjectVersion( project, next_version, current_user.id, - upload_changes, + file_changes, get_ip(request), user_agent, device_id, @@ -950,29 +932,27 @@ def chunk_upload(transaction_id, chunk_id): """ upload, upload_dir = get_upload(transaction_id) request.view_args["project"] = upload.project - upload_changes = ChangesSchema( - context={"version": upload.project.latest_version + 1} - ).load(upload.changes) - for f in upload_changes.added + upload_changes.updated: - if chunk_id in f.chunks: - dest = os.path.join(upload_dir, "chunks", chunk_id) - lockfile = os.path.join(upload_dir, "lockfile") - with Toucher(lockfile, 30): - try: - # we could have used request.data here, but it could eventually cause OOM issue - save_to_file( - request.stream, dest, current_app.config["MAX_CHUNK_SIZE"] - ) - except IOError: - move_to_tmp(dest, transaction_id) - abort(400, "Too big chunk") - if os.path.exists(dest): - checksum = generate_checksum(dest) - size = os.path.getsize(dest) - return jsonify({"checksum": checksum, "size": size}), 200 - else: - abort(400, "Upload was probably canceled") - abort(404) + chunks = [] + for file in upload.changes["added"] + upload.changes["updated"]: + chunks += 
file.get("chunks", []) + + if chunk_id not in chunks: + abort(404) + + dest = os.path.join(upload_dir, "chunks", chunk_id) + with Toucher(upload.lockfile, 30): + try: + # we could have used request.data here, but it could eventually cause OOM issue + save_to_file(request.stream, dest, current_app.config["MAX_CHUNK_SIZE"]) + except IOError: + move_to_tmp(dest, transaction_id) + abort(400, "Too big chunk") + if os.path.exists(dest): + checksum = generate_checksum(dest) + size = os.path.getsize(dest) + return jsonify({"checksum": checksum, "size": size}), 200 + else: + abort(400, "Upload was probably canceled") @auth_required @@ -1003,24 +983,34 @@ def push_finish(transaction_id): project_path = get_project_path(project) next_version = project.latest_version + 1 v_next_version = ProjectVersion.to_v_name(next_version) - changes = ChangesSchema(context={"version": next_version}).load(upload.changes) + try: + upload_changes = ChangesSchema().load(upload.changes) + except ValidationError as err: + msg = err.messages[0] if type(err.messages) == list else "Invalid input data" + abort(422, msg) + + file_changes = files_changes_from_upload(upload_changes, next_version) + chunks_map = { + f["path"]: f["chunks"] + for f in upload_changes["added"] + upload_changes["updated"] + } corrupted_files = [] - for f in changes.added + changes.updated: - if f.diff is not None: - dest_file = os.path.join(upload_dir, "files", f.diff.location) - expected_size = f.diff.size - else: - dest_file = os.path.join(upload_dir, "files", f.location) - expected_size = f.size + # Concatenate chunks into single file + for f in file_changes: + if f.change == PushChangeType.DELETE: + continue - # Concatenate chunks into single file + f_path = ( + f.diff.location if f.change == PushChangeType.UPDATE_DIFF else f.location + ) + temporary_location = os.path.join(upload_dir, "files", f_path) # TODO we need to move this elsewhere since it can fail for large files (and slow FS) - 
os.makedirs(os.path.dirname(dest_file), exist_ok=True) - with open(dest_file, "wb") as dest: + os.makedirs(os.path.dirname(temporary_location), exist_ok=True) + with open(temporary_location, "wb") as dest: try: - for chunk_id in f.chunks: + for chunk_id in chunks_map.get(f.path, []): sleep(0) # to unblock greenlet chunk_file = os.path.join(upload_dir, "chunks", chunk_id) with open(chunk_file, "rb") as src: @@ -1035,24 +1025,27 @@ def push_finish(transaction_id): ) corrupted_files.append(f.path) continue - if not is_supported_type(dest_file): - logging.info(f"Rejecting blacklisted file: {dest_file}") + if not is_supported_type(temporary_location): + logging.info(f"Rejecting blacklisted file: {temporary_location}") abort( 400, f"Unsupported file type detected: {f.path}. " f"Please remove the file or try compressing it into a ZIP file before uploading.", ) - if expected_size != os.path.getsize(dest_file): + # check if .gpkg file is valid + if is_versioned_file(temporary_location) and not f.is_valid_gpkg(): + corrupted_files.append(f.path) + continue + + expected_size = ( + f.diff.size if f.change == PushChangeType.UPDATE_DIFF else f.size + ) + if expected_size != os.path.getsize(temporary_location): logging.error( - "Data integrity check has failed on file %s in project %s" - % (f.path, project_path), + f"Data integrity check has failed on file {f.path} in project {project_path}", exc_info=True, ) - # check if .gpkg file is valid - if is_versioned_file(dest_file): - if not f.is_valid_gpkg(): - corrupted_files.append(f.path) corrupted_files.append(f.path) if corrupted_files: @@ -1076,40 +1069,48 @@ def push_finish(transaction_id): os.renames(files_dir, target_dir) # apply gpkg updates sync_errors = {} - to_remove = [i.path for i in changes.removed] + to_remove = [i.path for i in file_changes if i.change == PushChangeType.DELETE] current_files = [f for f in project.files if f.path not in to_remove] - for updated_file in changes.updated: - # yield to gevent hub since 
geodiff action can take some time to prevent worker timeout - sleep(0) - current_file = next( - (i for i in current_files if i.path == updated_file.path), None - ) - if not current_file: - sync_errors[updated_file.path] = "file not found on server " - continue - - if updated_file.diff: - result = project.storage.apply_diff( - current_file, updated_file, next_version + for file in file_changes: + # for updates try to apply diff to create a full updated gpkg file or from full .gpkg try to create corresponding diff + if file.change in ( + PushChangeType.UPDATE, + PushChangeType.UPDATE_DIFF, + ) and is_versioned_file(file.path): + current_file = next( + (i for i in current_files if i.path == file.path), None ) - if result.ok(): - checksum, size = result.value - updated_file.checksum = checksum - updated_file.size = size - else: - sync_errors[updated_file.path] = ( - f"project: {project.workspace.name}/{project.name}, {result.value}" - ) + if not current_file: + sync_errors[file.path] = "file not found on server " + continue - elif is_versioned_file(updated_file.path): - result = project.storage.construct_diff( - current_file, updated_file, next_version - ) - if result.ok(): - updated_file.diff = result.value + # yield to gevent hub since geodiff action can take some time to prevent worker timeout + sleep(0) + + if file.diff: + result = project.storage.apply_diff( + current_file, file, next_version + ) + if result.ok(): + checksum, size = result.value + file.checksum = checksum + file.size = size + else: + sync_errors[file.path] = ( + f"project: {project.workspace.name}/{project.name}, {result.value}" + ) else: - # if diff cannot be constructed it would be force update - logging.warning(f"Geodiff: create changeset error {result.value}") + result = project.storage.construct_diff( + current_file, file, next_version + ) + if result.ok(): + file.diff = result.value + file.change = PushChangeType.UPDATE_DIFF + else: + # if diff cannot be constructed it would be force update 
+ logging.warning( + f"Geodiff: create changeset error {result.value}" + ) if sync_errors: msg = "" @@ -1123,7 +1124,7 @@ def push_finish(transaction_id): project, next_version, current_user.id, - changes, + file_changes, get_ip(request), user_agent, device_id, @@ -1257,15 +1258,25 @@ def clone_project(namespace, project_name): # noqa: E501 user_agent = get_user_agent(request) device_id = get_device_id(request) # transform source files to new uploaded files - files = UploadFileSchema(context={"version": 1}, many=True).load( - FileSchema(exclude=("location",), many=True).dump(cloned_project.files) - ) - changes = UploadChanges(added=files, updated=[], removed=[]) + file_changes = [] + for file in cloned_project.files: + file_changes.append( + ProjectFileChange( + file.path, + file.checksum, + file.size, + diff=None, + mtime=None, + location=os.path.join("v1", mergin_secure_filename(file.path)), + change=PushChangeType.CREATE, + ) + ) + project_version = ProjectVersion( p, version, current_user.id, - changes, + file_changes, get_ip(request), user_agent, device_id, diff --git a/server/mergin/sync/schemas.py b/server/mergin/sync/schemas.py index 75b6f09e..617ad531 100644 --- a/server/mergin/sync/schemas.py +++ b/server/mergin/sync/schemas.py @@ -75,7 +75,7 @@ def project_user_permissions(project): class FileHistorySchema(ma.SQLAlchemyAutoSchema): mtime = DateTimeWithZ() - diff = fields.Nested(FileSchema(), attribute="diff_file", exclude=("location",)) + diff = fields.Nested(FileSchema(), attribute="diff_file") expiration = DateTimeWithZ(attribute="expiration", dump_only=True) class Meta: diff --git a/server/mergin/sync/storages/disk.py b/server/mergin/sync/storages/disk.py index 4debb255..459d704d 100644 --- a/server/mergin/sync/storages/disk.py +++ b/server/mergin/sync/storages/disk.py @@ -21,7 +21,7 @@ generate_checksum, is_versioned_file, ) -from ..files import mergin_secure_filename, ProjectFile, UploadFile, File +from ..files import ProjectDiffFile, 
mergin_secure_filename, ProjectFile def save_to_file(stream, path, max_size=None): @@ -245,7 +245,7 @@ def _generator(): return _generator() def apply_diff( - self, current_file: ProjectFile, upload_file: UploadFile, version: int + self, current_file: ProjectFile, upload_file: ProjectFile, version: int ) -> Result: """Apply geodiff diff file on current gpkg basefile. Creates GeodiffActionHistory record of the action. Returns checksum and size of generated file. If action fails it returns geodiff error message. @@ -313,7 +313,7 @@ def apply_diff( return Err(self.gediff_log.getvalue()) def construct_diff( - self, current_file: ProjectFile, upload_file: UploadFile, version: int + self, current_file: ProjectFile, upload_file: ProjectFile, version: int ) -> Result: """Construct geodiff diff file from uploaded gpkg and current basefile. Returns diff metadata as a result. If action fails it returns geodiff error message. @@ -345,7 +345,7 @@ def construct_diff( basefile_tmp, uploaded_file_tmp, changeset_tmp ) # create diff metadata as it would be created by other clients - diff_file = File( + diff_file = ProjectDiffFile( path=diff_name, checksum=generate_checksum(changeset_tmp), size=os.path.getsize(changeset_tmp), diff --git a/server/mergin/tests/fixtures.py b/server/mergin/tests/fixtures.py index 7cff688e..9f39909d 100644 --- a/server/mergin/tests/fixtures.py +++ b/server/mergin/tests/fixtures.py @@ -19,7 +19,7 @@ from ..stats.models import MerginInfo from . 
import test_project, test_workspace_id, test_project_dir, TMP_DIR from .utils import login_as_admin, initialize, cleanup, file_info -from ..sync.files import ChangesSchema +from ..sync.files import files_changes_from_upload thisdir = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.join(thisdir, os.pardir)) @@ -213,12 +213,13 @@ def diff_project(app): else: # no files uploaded, hence no action needed pass - upload_changes = ChangesSchema(context={"version": i + 2}).load(change) + + file_changes = files_changes_from_upload(change, version=i + 2) pv = ProjectVersion( project, i + 2, project.creator.id, - upload_changes, + file_changes, "127.0.0.1", ) assert pv.project_size == sum(file.size for file in pv.files) diff --git a/server/mergin/tests/test_db_hooks.py b/server/mergin/tests/test_db_hooks.py index f896f99b..48c14c0e 100644 --- a/server/mergin/tests/test_db_hooks.py +++ b/server/mergin/tests/test_db_hooks.py @@ -18,7 +18,6 @@ ProjectRole, ProjectUser, ) -from ..sync.files import UploadChanges from ..auth.models import User from ..app import db from . 
import DEFAULT_USER @@ -40,8 +39,7 @@ def test_close_user_account(client, diff_project): # user has access to mergin user diff_project diff_project.set_role(user.id, ProjectRole.WRITER) # user contributed to another user project so he is listed in projects history - changes = UploadChanges(added=[], updated=[], removed=[]) - pv = ProjectVersion(diff_project, 11, user.id, changes, "127.0.0.1") + pv = ProjectVersion(diff_project, 11, user.id, [], "127.0.0.1") diff_project.latest_version = pv.name pv.project = diff_project db.session.add(pv) @@ -116,8 +114,9 @@ def test_remove_project(client, diff_project): # set up mergin_user = User.query.filter_by(username=DEFAULT_USER[0]).first() project_dir = Path(diff_project.storage.project_dir) - changes = UploadChanges(added=[], removed=[], updated=[]) - upload = Upload(diff_project, changes, mergin_user.id) + upload = Upload( + diff_project, {"added": [], "removed": [], "updated": []}, mergin_user.id + ) db.session.add(upload) project_id = diff_project.id user = add_user("user", "user") diff --git a/server/mergin/tests/test_project_controller.py b/server/mergin/tests/test_project_controller.py index 1b38257a..989133e6 100644 --- a/server/mergin/tests/test_project_controller.py +++ b/server/mergin/tests/test_project_controller.py @@ -38,7 +38,7 @@ PushChangeType, ProjectFilePath, ) -from ..sync.files import ChangesSchema, UploadChanges +from ..sync.files import ChangesSchema, files_changes_from_upload from ..sync.schemas import ProjectListSchema from ..sync.utils import generate_checksum, is_versioned_file, get_project_path from ..auth.models import User, UserProfile @@ -1205,13 +1205,7 @@ def test_push_to_new_project(client): project, 0, p.creator.id, - ChangesSchema(context={"version": 0}).load( - { - "added": [], - "updated": [], - "removed": [], - } - ), + [], ip="127.0.0.1", ) db.session.add(pv) @@ -1411,8 +1405,7 @@ def create_transaction(username, changes, version=1): project = Project.query.filter_by( 
name=test_project, workspace_id=test_workspace_id ).first() - upload_changes = ChangesSchema(context={"version": version}).load(changes) - upload = Upload(project, upload_changes, user.id) + upload = Upload(project, changes, user.id) db.session.add(upload) db.session.commit() upload_dir = os.path.join(upload.project.storage.project_dir, "tmp", upload.id) @@ -1525,7 +1518,7 @@ def test_push_finish(client): upload.project, upload.project.latest_version, upload.project.creator.id, - UploadChanges(added=[], updated=[], removed=[]), + [], "127.0.0.1", ) pv.project = upload.project @@ -2436,12 +2429,12 @@ def add_project_version(project, changes, version=None): else User.query.filter_by(username=DEFAULT_USER[0]).first() ) next_version = version or project.next_version() - upload_changes = ChangesSchema(context={"version": next_version}).load(changes) + file_changes = files_changes_from_upload(changes, version=next_version) pv = ProjectVersion( project, next_version, author.id, - upload_changes, + file_changes, ip="127.0.0.1", ) db.session.add(pv) diff --git a/server/mergin/tests/utils.py b/server/mergin/tests/utils.py index 94fc033f..0c4448a6 100644 --- a/server/mergin/tests/utils.py +++ b/server/mergin/tests/utils.py @@ -20,7 +20,7 @@ from ..auth.models import User, UserProfile from ..sync.utils import generate_location, generate_checksum from ..sync.models import Project, ProjectVersion, FileHistory, ProjectRole -from ..sync.files import UploadChanges, ChangesSchema +from ..sync.files import ProjectFileChange, PushChangeType, files_changes_from_upload from ..sync.workspace import GlobalWorkspace from ..app import db from . 
import json_headers, DEFAULT_USER, test_project, test_project_dir, TMP_DIR @@ -82,8 +82,7 @@ def create_project(name, workspace, user, **kwargs): p.updated = datetime.utcnow() db.session.add(p) db.session.flush() - changes = UploadChanges(added=[], updated=[], removed=[]) - pv = ProjectVersion(p, 0, user.id, changes, "127.0.0.1") + pv = ProjectVersion(p, 0, user.id, [], "127.0.0.1") db.session.add(pv) db.session.commit() @@ -156,15 +155,17 @@ def initialize(): for f in files: abs_path = os.path.join(root, f) project_files.append( - { - "path": abs_path.replace(test_project_dir, "").lstrip("/"), - "location": os.path.join( + ProjectFileChange( + path=abs_path.replace(test_project_dir, "").lstrip("/"), + checksum=generate_checksum(abs_path), + size=os.path.getsize(abs_path), + mtime=str(datetime.fromtimestamp(os.path.getmtime(abs_path))), + change=PushChangeType.CREATE, + location=os.path.join( "v1", abs_path.replace(test_project_dir, "").lstrip("/") ), - "size": os.path.getsize(abs_path), - "checksum": generate_checksum(abs_path), - "mtime": str(datetime.fromtimestamp(os.path.getmtime(abs_path))), - } + diff=None, + ) ) p.latest_version = 1 p.public = True @@ -173,14 +174,7 @@ def initialize(): db.session.add(p) db.session.commit() - upload_changes = ChangesSchema(context={"version": 1}).load( - { - "added": project_files, - "updated": [], - "removed": [], - } - ) - pv = ProjectVersion(p, 1, user.id, upload_changes, "127.0.0.1") + pv = ProjectVersion(p, 1, user.id, project_files, "127.0.0.1") db.session.add(pv) db.session.commit() @@ -285,7 +279,7 @@ def create_blank_version(project): project, project.next_version(), project.creator.id, - UploadChanges(added=[], updated=[], removed=[]), + [], "127.0.0.1", ) db.session.add(pv) @@ -355,14 +349,12 @@ def push_change(project, action, path, src_dir): else: return - upload_changes = ChangesSchema(context={"version": project.next_version()}).load( - changes - ) + file_changes = files_changes_from_upload(changes, 
version=project.next_version()) pv = ProjectVersion( project, project.next_version(), project.creator.id, - upload_changes, + file_changes, "127.0.0.1", ) db.session.add(pv)