diff --git a/mergin/client.py b/mergin/client.py index 14ebe033..169b2198 100644 --- a/mergin/client.py +++ b/mergin/client.py @@ -14,6 +14,7 @@ import ssl from enum import Enum, auto import re +import typing import warnings from .common import ClientError, LoginError, InvalidProject @@ -22,6 +23,8 @@ download_file_finalize, download_project_async, download_file_async, + download_files_async, + download_files_finalize, download_diffs_async, download_project_finalize, download_project_wait, @@ -1127,3 +1130,63 @@ def has_writing_permissions(self, project_path): """ info = self.project_info(project_path) return info["permissions"]["upload"] + + def reset_local_changes(self, directory: str, files_to_reset: typing.List[str] = None) -> None: + """ + Reset local changes to either all files or only listed files. + Added files are removed, removed files are brought back and updates are discarded. + + :param directory: Project's directory + :type directory: String + :param files_to_reset List of files to reset, relative paths of file + :type files_to_reset: List of strings, default None + """ + all_files = files_to_reset is None + + mp = MerginProject(directory) + + current_version = mp.version() + + push_changes = mp.get_push_changes() + + files_download = [] + + # remove all added files + for file in push_changes["added"]: + if all_files or file["path"] in files_to_reset: + os.remove(mp.fpath(file["path"])) + + # update files get override with previous version + for file in push_changes["updated"]: + if all_files or file["path"] in files_to_reset: + if mp.is_versioned_file(file["path"]): + mp.geodiff.make_copy_sqlite(mp.fpath_meta(file["path"]), mp.fpath(file["path"])) + else: + files_download.append(file["path"]) + + # removed files are redownloaded + for file in push_changes["removed"]: + if all_files or file["path"] in files_to_reset: + files_download.append(file["path"]) + + if files_download: + self.download_files(directory, files_download, version=current_version) + + def download_files( + self, project_dir: str, file_paths: typing.List[str], output_paths: typing.List[str] = None, version: str = None + ): + """ + Download project files at specified version. Get the latest if no version specified. + + :param project_dir: project local directory + :type project_dir: String + :param file_path: List of relative paths of files to download in the project directory + :type file_path: List[String] + :param output_paths: List of paths for files to download to. Should be same length of as file_path. Default is `None` which means that files are downloaded into MerginProject at project_dir. + :type output_paths: List[String] + :param version: optional version tag for downloaded file + :type version: String + """ + job = download_files_async(self, project_dir, file_paths, output_paths, version=version) + pull_project_wait(job) + download_files_finalize(job) diff --git a/mergin/client_pull.py b/mergin/client_pull.py index 7baf9463..645a61ac 100644 --- a/mergin/client_pull.py +++ b/mergin/client_pull.py @@ -15,6 +15,7 @@ import pprint import shutil import tempfile +import typing import concurrent.futures @@ -621,77 +622,14 @@ def download_file_async(mc, project_dir, file_path, output_file, version): Starts background download project file at specified version. Returns handle to the pending download. """ - mp = MerginProject(project_dir) - project_path = mp.project_full_name() - ver_info = f"at version {version}" if version is not None else "at latest version" - mp.log.info(f"Getting {file_path} {ver_info}") - latest_proj_info = mc.project_info(project_path) - if version: - project_info = mc.project_info(project_path, version=version) - else: - project_info = latest_proj_info - mp.log.info(f"Got project info. version {project_info['version']}") - - # set temporary directory for download - temp_dir = tempfile.mkdtemp(prefix="mergin-py-client-") - - download_list = [] - update_tasks = [] - total_size = 0 - # None can not be used to indicate latest version of the file, so - # it is necessary to pass actual version. - if version is None: - version = latest_proj_info["version"] - for file in project_info["files"]: - if file["path"] == file_path: - file["version"] = version - items = _download_items(file, temp_dir) - is_latest_version = version == latest_proj_info["version"] - task = UpdateTask(file["path"], items, output_file, latest_version=is_latest_version) - download_list.extend(task.download_queue_items) - for item in task.download_queue_items: - total_size += item.size - update_tasks.append(task) - break - if not download_list: - warn = f"No {file_path} exists at version {version}" - mp.log.warning(warn) - shutil.rmtree(temp_dir) - raise ClientError(warn) - - mp.log.info(f"will download file {file_path} in {len(download_list)} chunks, total size {total_size}") - job = DownloadJob(project_path, total_size, version, update_tasks, download_list, temp_dir, mp, project_info) - job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) - job.futures = [] - for item in download_list: - future = job.executor.submit(_do_download, item, mc, mp, project_path, job) - job.futures.append(future) - - return job + return download_files_async(mc, project_dir, [file_path], [output_file], version) def download_file_finalize(job): """ To be called when download_file_async is finished """ - job.executor.shutdown(wait=True) - - # make sure any exceptions from threads are not lost - for future in job.futures: - if future.exception() is not None: - raise future.exception() - - job.mp.log.info("--- download finished") - - temp_dir = None - for task in job.update_tasks: - task.apply(job.directory, job.mp) - if task.download_queue_items: - temp_dir = os.path.dirname(task.download_queue_items[0].download_file_path) - - # Remove temporary download directory - if temp_dir is not None: - shutil.rmtree(temp_dir) + download_files_finalize(job) def download_diffs_async(mc, project_directory, file_path, versions): @@ -804,3 +742,103 @@ def download_diffs_finalize(job): job.mp.log.info("--- diffs pull finished") return diffs + + +def download_files_async( + mc, project_dir: str, file_paths: typing.List[str], output_paths: typing.List[str], version: str +): + """ + Starts background download project files at specified version. + Returns handle to the pending download. + """ + mp = MerginProject(project_dir) + project_path = mp.project_full_name() + ver_info = f"at version {version}" if version is not None else "at latest version" + mp.log.info(f"Getting [{', '.join(file_paths)}] {ver_info}") + latest_proj_info = mc.project_info(project_path) + if version: + project_info = mc.project_info(project_path, version=version) + else: + project_info = latest_proj_info + mp.log.info(f"Got project info. version {project_info['version']}") + + # set temporary directory for download + temp_dir = tempfile.mkdtemp(prefix="mergin-py-client-") + + if output_paths is None: + output_paths = [] + for file in file_paths: + output_paths.append(mp.fpath(file)) + + if len(output_paths) != len(file_paths): + warn = "Output file paths are not of the same length as file paths. Cannot store required files." + mp.log.warning(warn) + shutil.rmtree(temp_dir) + raise ClientError(warn) + + download_list = [] + update_tasks = [] + total_size = 0 + # None can not be used to indicate latest version of the file, so + # it is necessary to pass actual version. + if version is None: + version = latest_proj_info["version"] + for file in project_info["files"]: + if file["path"] in file_paths: + index = file_paths.index(file["path"]) + file["version"] = version + items = _download_items(file, temp_dir) + is_latest_version = version == latest_proj_info["version"] + task = UpdateTask(file["path"], items, output_paths[index], latest_version=is_latest_version) + download_list.extend(task.download_queue_items) + for item in task.download_queue_items: + total_size += item.size + update_tasks.append(task) + + missing_files = [] + files_to_download = [] + project_file_paths = [file["path"] for file in project_info["files"]] + for file in file_paths: + if file not in project_file_paths: + missing_files.append(file) + else: + files_to_download.append(file) + + if not download_list or missing_files: + warn = f"No [{', '.join(missing_files)}] exists at version {version}" + mp.log.warning(warn) + shutil.rmtree(temp_dir) + raise ClientError(warn) + + mp.log.info( + f"will download files [{', '.join(files_to_download)}] in {len(download_list)} chunks, total size {total_size}" + ) + job = DownloadJob(project_path, total_size, version, update_tasks, download_list, temp_dir, mp, project_info) + job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) + job.futures = [] + for item in download_list: + future = job.executor.submit(_do_download, item, mc, mp, project_path, job) + job.futures.append(future) + + return job + + +def download_files_finalize(job): + """ + To be called when download_file_async is finished + """ + job.executor.shutdown(wait=True) + + # make sure any exceptions from threads are not lost + for future in job.futures: + if future.exception() is not None: + raise future.exception() + + job.mp.log.info("--- download finished") + + for task in job.update_tasks: + task.apply(job.directory, job.mp) + + # Remove temporary download directory + if job.directory is not None and os.path.exists(job.directory): + shutil.rmtree(job.directory) diff --git a/mergin/test/test_client.py b/mergin/test/test_client.py index 36cbd0d2..e4ca10d8 100644 --- a/mergin/test/test_client.py +++ b/mergin/test/test_client.py @@ -1021,8 +1021,8 @@ def test_download_file(mc): assert check_gpkg_same_content(mp, f_downloaded, expected) # make sure there will be exception raised if a file doesn't exist in the version - with pytest.raises(ClientError, match=f"No {f_updated} exists at version v5"): - mc.download_file(project_dir, f_updated, f_downloaded, version=f"v5") + with pytest.raises(ClientError, match=f"No \\[{f_updated}\\] exists at version v5"): + mc.download_file(project_dir, f_updated, f_downloaded, version="v5") def test_download_diffs(mc): @@ -2005,6 +2005,127 @@ def test_clean_diff_files(mc): assert diff_files == [] +def test_reset_local_changes(mc: MerginClient): + test_project = f"test_reset_local_changes" + project = API_USER + "/" + test_project + project_dir = os.path.join(TMP_DIR, test_project) # primary project dir for updates + project_dir_2 = os.path.join(TMP_DIR, test_project + "_v2") # primary project dir for updates + + cleanup(mc, project, [project_dir]) + # create remote project + shutil.copytree(TEST_DATA_DIR, project_dir) + mc.create_project_and_push(test_project, project_dir) + + # test push changes with diffs: + mp = MerginProject(project_dir) + + # test with no changes, should pass by doing nothing + mc.reset_local_changes(project_dir) + + f_updated = "base.gpkg" + shutil.copy(mp.fpath("inserted_1_A.gpkg"), mp.fpath(f_updated)) + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_test.txt")) + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_dir/new_test.txt")) + os.remove(mp.fpath("test.txt")) + os.remove(mp.fpath("test_dir/test2.txt")) + with open(mp.fpath("test3.txt"), mode="a", encoding="utf-8") as file: + file.write(" Add some text.") + + # push changes prior to reset + mp = MerginProject(project_dir) + push_changes = mp.get_push_changes() + + assert len(push_changes["added"]) == 2 + assert len(push_changes["removed"]) == 2 + assert len(push_changes["updated"]) == 2 + + # reset all files back + mc.reset_local_changes(project_dir) + + # push changes after the reset + mp = MerginProject(project_dir) + push_changes = mp.get_push_changes() + + assert len(push_changes["added"]) == 0 + assert len(push_changes["removed"]) == 0 + assert len(push_changes["updated"]) == 0 + + cleanup(mc, project, [project_dir]) + # create remote project + shutil.copytree(TEST_DATA_DIR, project_dir) + mc.create_project_and_push(test_project, project_dir) + + # test push changes with diffs: + mp = MerginProject(project_dir) + + shutil.copy(mp.fpath("inserted_1_A.gpkg"), mp.fpath(f_updated)) + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_test.txt")) + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_dir/new_test.txt")) + os.remove(mp.fpath("test.txt")) + os.remove(mp.fpath("test_dir/test2.txt")) + + # push changes prior to reset + mp = MerginProject(project_dir) + push_changes = mp.get_push_changes() + + assert len(push_changes["added"]) == 2 + assert len(push_changes["removed"]) == 2 + assert len(push_changes["updated"]) == 1 + + # reset local changes only to certain files, one added and one removed + mc.reset_local_changes(project_dir, files_to_reset=["new_test.txt", "test_dir/test2.txt"]) + + # push changes after the reset + mp = MerginProject(project_dir) + push_changes = mp.get_push_changes() + + assert len(push_changes["added"]) == 1 + assert len(push_changes["removed"]) == 1 + assert len(push_changes["updated"]) == 1 + + cleanup(mc, project, [project_dir, project_dir_2]) + # create remote project + shutil.copytree(TEST_DATA_DIR, project_dir) + mc.create_project_and_push(test_project, project_dir) + + # test push changes with diffs: + mp = MerginProject(project_dir) + + # make changes creating two another versions + shutil.copy(mp.fpath("inserted_1_A.gpkg"), mp.fpath(f_updated)) + mc.push_project(project_dir) + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_test.txt")) + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_dir/new_test.txt")) + mc.push_project(project_dir) + os.remove(mp.fpath("test.txt")) + os.remove(mp.fpath("test_dir/test2.txt")) + + # download version 2 and create MerginProject for it + mc.download_project(project, project_dir_2, version="v2") + mp = MerginProject(project_dir_2) + + # make some changes + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_test.txt")) + shutil.copy(mp.fpath("test.txt"), mp.fpath("new_dir/new_test.txt")) + os.remove(mp.fpath("test.txt")) + os.remove(mp.fpath("test_dir/test2.txt")) + + # check changes + push_changes = mp.get_push_changes() + assert len(push_changes["added"]) == 2 + assert len(push_changes["removed"]) == 2 + assert len(push_changes["updated"]) == 0 + + # reset back to original version we had - v2 + mc.reset_local_changes(project_dir_2) + + # push changes after the reset - should be none + push_changes = mp.get_push_changes() + assert len(push_changes["added"]) == 0 + assert len(push_changes["removed"]) == 0 + assert len(push_changes["updated"]) == 0 + + def test_project_metadata(mc): test_project = "test_project_metadata" project = API_USER + "/" + test_project @@ -2038,3 +2159,58 @@ def test_project_metadata(mc): assert mp.project_name() == test_project assert mp.workspace_name() == API_USER assert mp.version() == "v0" + + +def test_download_files(mc: MerginClient): + """Test downloading files at specified versions.""" + test_project = "test_download_files" + project = API_USER + "/" + test_project + project_dir = os.path.join(TMP_DIR, test_project) + f_updated = "base.gpkg" + download_dir = os.path.join(TMP_DIR, "test-download-files-tmp") + + cleanup(mc, project, [project_dir, download_dir]) + + mp = create_versioned_project(mc, test_project, project_dir, f_updated) + + project_info = mc.project_info(project) + assert project_info["version"] == "v5" + assert project_info["id"] == mp.project_id() + + # Versioned file should have the following content at versions 2-4 + expected_content = ("inserted_1_A.gpkg", "inserted_1_A_mod.gpkg", "inserted_1_B.gpkg") + + downloaded_file = os.path.join(download_dir, f_updated) + + # if output_paths is specified look at that location + for ver in range(2, 5): + mc.download_files(project_dir, [f_updated], [downloaded_file], version=f"v{ver}") + expected = os.path.join(TEST_DATA_DIR, expected_content[ver - 2]) # GeoPackage with expected content + assert check_gpkg_same_content(mp, downloaded_file, expected) + + # if output_paths is not specified look in the mergin project folder + for ver in range(2, 5): + mc.download_files(project_dir, [f_updated], version=f"v{ver}") + expected = os.path.join(TEST_DATA_DIR, expected_content[ver - 2]) # GeoPackage with expected content + assert check_gpkg_same_content(mp, mp.fpath(f_updated), expected) + + # download two files from v1 and check their content + file_2 = "test.txt" + downloaded_file_2 = os.path.join(download_dir, file_2) + + mc.download_files(project_dir, [f_updated, file_2], [downloaded_file, downloaded_file_2], version="v1") + assert check_gpkg_same_content(mp, downloaded_file, os.path.join(TEST_DATA_DIR, f_updated)) + + with open(os.path.join(TEST_DATA_DIR, file_2), mode="r", encoding="utf-8") as file: + content_exp = file.read() + + with open(os.path.join(download_dir, file_2), mode="r", encoding="utf-8") as file: + content = file.read() + assert content_exp == content + + # make sure there will be exception raised if a file doesn't exist in the version + with pytest.raises(ClientError, match=f"No \\[{f_updated}\\] exists at version v5"): + mc.download_files(project_dir, [f_updated], version="v5") + + with pytest.raises(ClientError, match=f"No \\[non_existing\\.file\\] exists at version v3"): + mc.download_files(project_dir, [f_updated, "non_existing.file"], version="v3")