diff --git a/bin/dvuploader-macos-latest b/bin/dvuploader-macos-latest
index 149de3a..3aad2b3 100755
Binary files a/bin/dvuploader-macos-latest and b/bin/dvuploader-macos-latest differ
diff --git a/bin/dvuploader-ubuntu-latest b/bin/dvuploader-ubuntu-latest
index b49b823..85c7a5e 100755
Binary files a/bin/dvuploader-ubuntu-latest and b/bin/dvuploader-ubuntu-latest differ
diff --git a/bin/dvuploader-windows-latest.exe b/bin/dvuploader-windows-latest.exe
index d4d0c67..084b3ff 100644
Binary files a/bin/dvuploader-windows-latest.exe and b/bin/dvuploader-windows-latest.exe differ
diff --git a/dvuploader/checksum.py b/dvuploader/checksum.py
index 1260fbd..595c832 100644
--- a/dvuploader/checksum.py
+++ b/dvuploader/checksum.py
@@ -1,5 +1,6 @@
 import hashlib
 from enum import Enum
+import os
 from typing import Callable

 from pydantic import BaseModel, Field
diff --git a/dvuploader/directupload.py b/dvuploader/directupload.py
index 8b0dddf..6e35213 100644
--- a/dvuploader/directupload.py
+++ b/dvuploader/directupload.py
@@ -7,12 +7,12 @@
 import requests
 from dotted_dict import DottedDict
 from requests.exceptions import HTTPError
-from requests.models import PreparedRequest
 from tqdm import tqdm
 from tqdm.utils import CallbackIOWrapper

 from dvuploader.file import File
 from dvuploader.chunkstream import ChunkStream
+from dvuploader.utils import build_url


 global MAX_RETRIES
@@ -20,6 +20,7 @@
 TICKET_ENDPOINT = "/api/datasets/:persistentId/uploadurls"
 ADD_FILE_ENDPOINT = "/api/datasets/:persistentId/addFiles"
 UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add?persistentId="
+REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"


 def direct_upload(
@@ -28,6 +29,7 @@ def direct_upload(
     dataverse_url: str,
     api_token: str,
     position: int,
+    n_parallel_uploads: int,
 ) -> bool:
     """
     Uploads a file to a Dataverse collection using direct upload.
@@ -38,6 +40,7 @@ def direct_upload(
         dataverse_url (str): The URL of the Dataverse instance to upload to.
         api_token (str): The API token to use for authentication.
         position (int): The position of the file in the list of files to upload.
+        n_parallel_uploads (int): The number of parallel uploads to perform.

     Returns:
         bool: True if the upload was successful, False otherwise.
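Note on the new endpoint template: `REPLACE_ENDPOINT` sits next to the existing upload endpoints and is used further down (in `_add_file_to_ds`) to switch between adding a new file and replacing one that already exists in the dataset. A minimal sketch of how the two templates resolve into request URLs; the base URL, DOI, and file id are placeholders, not values taken from this change, and the real code routes the replace URL through the new `build_url` helper, which for an endpoint without query parameters boils down to the same `urljoin`:

```python
# Illustrative only: how the endpoint templates above turn into concrete URLs.
from urllib.parse import urljoin

UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add?persistentId="
REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"

dataverse_url = "https://demo.dataverse.org"  # placeholder instance
persistent_id = "doi:10.70122/FK2/EXAMPLE"    # placeholder DOI
file_id = "42"                                # placeholder id of an existing file

# New files keep going to the dataset-level "add" endpoint ...
add_url = urljoin(dataverse_url, UPLOAD_ENDPOINT + persistent_id)
# ... while replacements target the file-level "replace" endpoint introduced here.
replace_url = urljoin(dataverse_url, REPLACE_ENDPOINT.format(FILE_ID=file_id))

print(add_url)      # https://demo.dataverse.org/api/datasets/:persistentId/add?persistentId=doi:10.70122/FK2/EXAMPLE
print(replace_url)  # https://demo.dataverse.org/api/files/42/replace
```

The `n_parallel_uploads` argument added to `direct_upload` above is threaded through to `_upload_multipart`, where it becomes the `size=` argument of `grequests.map` and thereby caps how many chunk requests are in flight at once (see the hunks that follow).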
@@ -68,6 +71,7 @@ def direct_upload( dataverse_url=dataverse_url, api_token=api_token, pbar=pbar, + n_parallel_uploads=n_parallel_uploads, ) result = _add_file_to_ds( @@ -75,6 +79,7 @@ def direct_upload( persistent_id, api_token, file, + n_parallel_uploads, ) if result is True: @@ -111,7 +116,7 @@ def _request_ticket( """ # Build request URL - query = _build_url( + query = build_url( endpoint=TICKET_ENDPOINT, dataverse_url=dataverse_url, key=api_token, @@ -126,26 +131,11 @@ def _request_ticket( raise HTTPError( f"Could not request a ticket for dataset '{persistent_id}' at '{dataverse_url}' \ \n\n{json.dumps(response.json(), indent=2)}" - ) + ) # type: ignore return DottedDict(response.json()["data"]) -def _build_url( - dataverse_url: str, - endpoint: str, - **kwargs, -) -> str: - """Builds a URL string, given access points and credentials""" - - req = PreparedRequest() - req.prepare_url(urljoin(dataverse_url, endpoint), kwargs) - - assert req.url is not None, f"Could not build URL for '{dataverse_url}'" - - return req.url - - def _upload_singlepart( response: Dict, filepath: str, @@ -179,7 +169,7 @@ def _upload_singlepart( raise HTTPError( f"Could not upload file \ \n\n{resp.headers}" - ) + ) # type: ignore return storage_identifier @@ -190,6 +180,7 @@ def _upload_multipart( dataverse_url: str, api_token: str, pbar: tqdm, + n_parallel_uploads: int, ): """ Uploads a file to Dataverse using multipart upload. @@ -200,6 +191,7 @@ def _upload_multipart( dataverse_url (str): The URL of the Dataverse instance. api_token (str): The API token for the Dataverse instance. pbar (tqdm): A progress bar to track the upload progress. + n_parallel_uploads (int): The number of parallel uploads to perform. Returns: str: The storage identifier for the uploaded file. @@ -228,7 +220,10 @@ def _upload_multipart( ) # Execute upload - responses = grequests.map(rs) + responses = grequests.map( + requests=rs, + size=n_parallel_uploads, + ) e_tags = [response.headers["ETag"] for response in responses] except Exception as e: @@ -302,7 +297,7 @@ def _complete_upload( raise HTTPError( f"Could not complete upload \ \n\n{json.dumps(response.json(), indent=2)}" - ) + ) # type: ignore def _abort_upload( @@ -321,8 +316,16 @@ def _add_file_to_ds( file: File, ): headers = {"X-Dataverse-key": api_token} - url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid) - payload = {"jsonData": file.json(by_alias=True)} + + if not file.to_replace: + url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid) + else: + url = build_url( + dataverse_url=dataverse_url, + endpoint=REPLACE_ENDPOINT.format(FILE_ID=file.file_id), + ) + + payload = {"jsonData": file.json(by_alias=True, exclude={"to_replace", "file_id"})} for _ in range(MAX_RETRIES): response = requests.post(url, headers=headers, files=payload) diff --git a/dvuploader/dvuploader.py b/dvuploader/dvuploader.py index 158866c..5b8f157 100644 --- a/dvuploader/dvuploader.py +++ b/dvuploader/dvuploader.py @@ -1,16 +1,21 @@ import grequests import requests -import json import os -from typing import Dict, List -from urllib.parse import urljoin +from typing import Dict, List, Optional -from pydantic import BaseModel, validator +from pydantic import BaseModel from joblib import Parallel, delayed from dotted_dict import DottedDict -from dvuploader.directupload import direct_upload +from dvuploader.directupload import ( + TICKET_ENDPOINT, + _abort_upload, + _validate_ticket_response, + direct_upload, +) from dvuploader.file import File +from dvuploader.nativeupload import native_upload +from 
dvuploader.utils import build_url, retrieve_dataset_files class DVUploader(BaseModel): @@ -34,6 +39,7 @@ def upload( dataverse_url: str, api_token: str, n_jobs: int = -1, + n_parallel_uploads: int = 1, ) -> None: """ Uploads the files to the specified Dataverse repository in parallel. @@ -43,6 +49,7 @@ def upload( dataverse_url (str): The URL of the Dataverse repository. api_token (str): The API token for the Dataverse repository. n_jobs (int): The number of parallel jobs to run. Defaults to -1. + n_parallel_uploads (int): The number of parallel uploads to execute. In the case of direct upload, this restricts the amount of parallel chunks per upload. Please use n_jobs to control parallel files. Returns: None @@ -57,26 +64,42 @@ def upload( # Sort files by size files = sorted( - self.files, key=lambda x: os.path.getsize(x.filepath), reverse=True + self.files, + key=lambda x: os.path.getsize(x.filepath), + reverse=True, ) if not self.files: print("\nāŒ No files to upload\n") return - # Upload files in parallel + # Check if direct upload is supported + has_direct_upload = self._has_direct_upload( + dataverse_url=dataverse_url, + api_token=api_token, + persistent_id=persistent_id, + ) + print("\nāš ļø Direct upload not supported. Falling back to Native API.") + print(f"\nšŸš€ Uploading files") - Parallel(n_jobs=n_jobs, backend="threading")( - delayed(direct_upload)( - file=file, + if not has_direct_upload: + self._execute_native_uploads( + files=files, dataverse_url=dataverse_url, api_token=api_token, persistent_id=persistent_id, - position=position, + n_parallel_uploads=n_parallel_uploads, + ) + else: + self._parallel_direct_upload( + files=files, + dataverse_url=dataverse_url, + api_token=api_token, + persistent_id=persistent_id, + n_jobs=n_jobs, + n_parallel_uploads=n_parallel_uploads, ) - for position, file in enumerate(files) - ) print("šŸŽ‰ Done!\n") @@ -97,7 +120,7 @@ def _check_duplicates( Prints a message for each file that already exists in the dataset with the same checksum. """ - ds_files = self._retrieve_dataset_files( + ds_files = retrieve_dataset_files( dataverse_url=dataverse_url, persistent_id=persistent_id, api_token=api_token, @@ -108,7 +131,11 @@ def _check_duplicates( to_remove = [] for file in self.files: - if any(map(lambda dsFile: self._check_hashes(file, dsFile), ds_files)): + has_same_hash = any( + map(lambda dsFile: self._check_hashes(file, dsFile), ds_files) + ) + + if has_same_hash and file.checksum: print( f"ā”œā”€ā”€ File '{file.fileName}' already exists with same {file.checksum.type} hash - Skipping upload." ) @@ -116,13 +143,45 @@ def _check_duplicates( else: print(f"ā”œā”€ā”€ File '{file.fileName}' is new - Uploading.") + # If present in dataset, replace file + file.file_id = self._get_file_id(file, ds_files) + file.to_replace = True if file.file_id else False + for file in to_remove: self.files.remove(file) print("šŸŽ‰ Done") @staticmethod - def _check_hashes(file: File, dsFile: Dict): + def _get_file_id( + file: File, + ds_files: List[DottedDict], + ) -> Optional[str]: + """ + Get the file ID for a given file in a dataset. + + Args: + file (File): The file object to find the ID for. + ds_files (List[Dict]): List of dictionary objects representing dataset files. + persistent_id (str): The persistent ID of the dataset. + + Returns: + str: The ID of the file. + + Raises: + ValueError: If the file cannot be found in the dataset. 
+ """ + + # Find the file that matches label and directoryLabel + for ds_file in ds_files: + dspath = os.path.join(ds_file.get("directoryLabel", ""), ds_file.label) + fpath = os.path.join(file.directoryLabel, file.fileName) # type: ignore + + if dspath == fpath: + return ds_file.dataFile.id + + @staticmethod + def _check_hashes(file: File, dsFile: DottedDict): """ Checks if a file has the same checksum as a file in the dataset. @@ -134,41 +193,126 @@ def _check_hashes(file: File, dsFile: Dict): bool: True if the files have the same checksum, False otherwise. """ + if not file.checksum: + return False + hash_algo, hash_value = tuple(dsFile.dataFile.checksum.values()) return file.checksum.value == hash_value and file.checksum.type == hash_algo @staticmethod - def _retrieve_dataset_files( + def _has_direct_upload( dataverse_url: str, + api_token: str, persistent_id: str, + ) -> bool: + """Checks if the response from the ticket request contains a direct upload URL""" + + query = build_url( + endpoint=TICKET_ENDPOINT, + dataverse_url=dataverse_url, + key=api_token, + persistentId=persistent_id, + size=1024, + ) + + # Send HTTP request + response = requests.get(query).json() + expected_error = "Direct upload not supported for files in this dataset" + + if "message" in response and expected_error in response["message"]: + return False + + # Abort test upload for now, if direct upload is supported + data = DottedDict(response.json()["data"]) + _validate_ticket_response(data) + _abort_upload( + data.abort, + dataverse_url, + api_token, + ) + + return True + + @staticmethod + def _execute_native_uploads( + files: List[File], + dataverse_url: str, api_token: str, + persistent_id: str, + n_parallel_uploads: int, ): """ - Retrieve the files of a specific dataset from a Dataverse repository. + Executes native uploads for the given files in parallel. - Parameters: - dataverse_url (str): The base URL of the Dataverse repository. - persistent_id (str): The persistent identifier (PID) of the dataset. + Args: + files (List[File]): The list of File objects to be uploaded. + dataverse_url (str): The URL of the Dataverse repository. + api_token (str): The API token for the Dataverse repository. + persistent_id (str): The persistent identifier of the Dataverse dataset. + n_parallel_uploads (int): The number of parallel uploads to execute. Returns: - list: A list of files in the dataset. + List[requests.Response]: The list of responses for each file upload. + """ - Raises: - HTTPError: If the request to the Dataverse repository fails. + tasks = [ + native_upload( + file=file, + dataverse_url=dataverse_url, + api_token=api_token, + persistent_id=persistent_id, + position=position, + ) + for position, file in enumerate(files) + ] + + # Execute tasks + responses = grequests.map(tasks, size=n_parallel_uploads) + + if not all(map(lambda x: x.status_code == 200, responses)): + errors = "\n".join( + ["\n\nāŒ Failed to upload files:"] + + [ + f"ā”œā”€ā”€ File '{file.fileName}' could not be uploaded: {response.status_code} {response.json()['message']}" + for file, response in zip(files, responses) + if response.status_code != 200 + ] + ) + + print(errors, "\n") + + @staticmethod + def _parallel_direct_upload( + files: List[File], + dataverse_url: str, + api_token: str, + persistent_id: str, + n_parallel_uploads: int, + n_jobs: int = -1, + ) -> None: """ + Perform parallel direct upload of files to the specified Dataverse repository. 
- DATASET_ENDPOINT = "/api/datasets/:persistentId/?persistentId={0}" + Args: + files (List[File]): A list of File objects to be uploaded. + dataverse_url (str): The URL of the Dataverse repository. + api_token (str): The API token for the Dataverse repository. + persistent_id (str): The persistent identifier of the Dataverse dataset. + n_jobs (int): The number of parallel jobs to run. Defaults to -1. - response = requests.get( - urljoin(dataverse_url, DATASET_ENDPOINT.format(persistent_id)), - headers={"X-Dataverse-key": api_token}, - ) + Returns: + None + """ - if response.status_code != 200: - raise requests.HTTPError( - f"Could not download dataset '{persistent_id}' at '{dataverse_url}' \ - \n\n{json.dumps(response.json(), indent=2)}" + Parallel(n_jobs=n_jobs, backend="threading")( + delayed(direct_upload)( + file=file, + dataverse_url=dataverse_url, + api_token=api_token, + persistent_id=persistent_id, + position=position, + n_parallel_uploads=n_parallel_uploads, ) - - return DottedDict(response.json()).data.latestVersion.files + for position, file in enumerate(files) + ) diff --git a/dvuploader/file.py b/dvuploader/file.py index e4263fd..b0b0940 100644 --- a/dvuploader/file.py +++ b/dvuploader/file.py @@ -1,7 +1,7 @@ import os from typing import List, Optional -from pydantic import BaseModel, Field, validator, ValidationError +from pydantic import BaseModel, Field, validator from dvuploader.checksum import Checksum, ChecksumTypes @@ -20,6 +20,8 @@ class File(BaseModel): storageIdentifier: Optional[str] = None fileName: Optional[str] = None checksum: Optional[Checksum] = None + to_replace: bool = False + file_id: Optional[str] = None @staticmethod def _validate_filepath(path): @@ -29,8 +31,7 @@ def _validate_filepath(path): raise TypeError(f"Filepath {path} is not a file.") elif not os.access(path, os.R_OK): raise TypeError(f"Filepath {path} is not readable.") - elif os.path.getsize(path) == 0: - raise ValueError(f"Filepath {path} is empty.") + return path @validator("fileName", always=True) diff --git a/dvuploader/nativeupload.py b/dvuploader/nativeupload.py new file mode 100644 index 0000000..9a7a0f8 --- /dev/null +++ b/dvuploader/nativeupload.py @@ -0,0 +1,79 @@ +import json +import os +import grequests +from dvuploader.directupload import _setup_pbar +from dvuploader.file import File +from dvuploader.utils import build_url, retrieve_dataset_files +from tqdm.utils import CallbackIOWrapper + + +NATIVE_UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add" +NATIVE_REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace" + + +def native_upload( + file: File, + dataverse_url: str, + api_token: str, + persistent_id: str, + position: int, +): + """ + Uploads a file to a Dataverse repository using the native upload method. + + Args: + file (File): The file to be uploaded. + dataverse_url (str): The URL of the Dataverse repository. + api_token (str): The API token for authentication. + persistent_id (str): The persistent identifier of the dataset. + position (int): The position of the file within the dataset. + + Returns: + Response: The response object from the upload request. 
+ """ + + pbar = _setup_pbar(file.filepath, position) + + if not file.to_replace: + url = build_url( + dataverse_url=dataverse_url, + endpoint=NATIVE_UPLOAD_ENDPOINT, + persistentId=persistent_id, + ) + else: + url = build_url( + dataverse_url=dataverse_url, + endpoint=NATIVE_REPLACE_ENDPOINT.format(FILE_ID=file.file_id), + ) + + header = {"X-Dataverse-key": api_token} + json_data = { + "description": file.description, + "forceReplace": "true", + "directoryLabel": file.directoryLabel, + "categories": file.categories, + "restrict": file.restrict, + "forceReplace": True, + } + + files = { + "jsonData": json.dumps(json_data), + "file": ( + os.path.basename(file.filepath), + CallbackIOWrapper(pbar.update, open(file.filepath, "rb"), "read"), + ), + } + + def _response_hook(response, *args, **kwargs): + filesize = os.path.getsize(file.filepath) + pbar.reset(filesize / 1024) + pbar.update(filesize / 1024) + pbar.close() + return response + + return grequests.post( + url=url, + headers=header, + files=files, + hooks=dict(response=_response_hook), + ) diff --git a/dvuploader/utils.py b/dvuploader/utils.py new file mode 100644 index 0000000..b7ecb8f --- /dev/null +++ b/dvuploader/utils.py @@ -0,0 +1,55 @@ +import json +from urllib.parse import urljoin +from requests import PreparedRequest +import requests +from dotted_dict import DottedDict + + +def build_url( + dataverse_url: str, + endpoint: str, + **kwargs, +) -> str: + """Builds a URL string, given access points and credentials""" + + req = PreparedRequest() + req.prepare_url(urljoin(dataverse_url, endpoint), kwargs) + + assert req.url is not None, f"Could not build URL for '{dataverse_url}'" + + return req.url + + +def retrieve_dataset_files( + dataverse_url: str, + persistent_id: str, + api_token: str, +): + """ + Retrieve the files of a specific dataset from a Dataverse repository. + + Parameters: + dataverse_url (str): The base URL of the Dataverse repository. + persistent_id (str): The persistent identifier (PID) of the dataset. + + Returns: + list: A list of files in the dataset. + + Raises: + HTTPError: If the request to the Dataverse repository fails. + """ + + DATASET_ENDPOINT = "/api/datasets/:persistentId/?persistentId={0}" + + response = requests.get( + urljoin(dataverse_url, DATASET_ENDPOINT.format(persistent_id)), + headers={"X-Dataverse-key": api_token}, + ) + + if response.status_code != 200: + raise requests.HTTPError( + f"Could not download dataset '{persistent_id}' at '{dataverse_url}' \ + \n\n{json.dumps(response.json(), indent=2)}" + ) # type: ignore + + return DottedDict(response.json()).data.latestVersion.files