diff --git a/dvuploader/dvuploader.py b/dvuploader/dvuploader.py
index 0653e05..d6aa9bb 100644
--- a/dvuploader/dvuploader.py
+++ b/dvuploader/dvuploader.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 from urllib.parse import urljoin
 import requests
 import os
@@ -19,6 +20,7 @@ from dvuploader.nativeupload import native_upload
 from dvuploader.utils import build_url, retrieve_dataset_files, setup_pbar
 
 
+
 class DVUploader(BaseModel):
     """
     A class for uploading files to a Dataverse repository.
@@ -153,10 +155,7 @@ async def _validate_and_hash_files(self, verbose: bool):
 
         if not verbose:
             tasks = [
-                self._validate_and_hash_file(
-                    file=file,
-                    verbose=self.verbose
-                )
+                self._validate_and_hash_file(file=file, verbose=self.verbose)
                 for file in self.files
             ]
 
@@ -175,10 +174,7 @@ async def _validate_and_hash_files(self, verbose: bool):
 
             tasks = [
                 self._validate_and_hash_file(
-                    file=file,
-                    progress=progress,
-                    task_id=task,
-                    verbose=self.verbose
+                    file=file, progress=progress, task_id=task, verbose=self.verbose
                 )
                 for file in self.files
             ]
@@ -197,7 +193,7 @@ async def _validate_and_hash_file(
         file.extract_file_name_hash_file()
 
         if verbose:
-            progress.update(task_id, advance=1) # type: ignore
+            progress.update(task_id, advance=1)  # type: ignore
 
     def _check_duplicates(
         self,
@@ -240,7 +236,7 @@ def _check_duplicates(
                 map(lambda dsFile: self._check_hashes(file, dsFile), ds_files)
             )
 
-            if has_same_hash and file.checksum:
+            if has_same_hash:
                 n_skip_files += 1
                 table.add_row(
                     file.file_name, "[bright_black]Same hash", "[bright_black]Skip"
@@ -316,12 +312,14 @@ def _check_hashes(file: File, dsFile: Dict):
             return False
 
         hash_algo, hash_value = tuple(dsFile["dataFile"]["checksum"].values())
+        path = os.path.join(
+            dsFile.get("directoryLabel", ""), dsFile["dataFile"]["filename"]
+        )
 
         return (
             file.checksum.value == hash_value
             and file.checksum.type == hash_algo
-            and file.file_name == dsFile["label"]
-            and file.directory_label == dsFile.get("directoryLabel", "")
+            and path == os.path.join(file.directory_label, file.file_name)  # type: ignore
         )
 
     @staticmethod
diff --git a/dvuploader/nativeupload.py b/dvuploader/nativeupload.py
index 16ba67b..bb5c3fb 100644
--- a/dvuploader/nativeupload.py
+++ b/dvuploader/nativeupload.py
@@ -3,6 +3,7 @@
 import os
 import tempfile
 from typing import List, Tuple
+from typing_extensions import Dict
 
 import aiofiles
 import aiohttp
@@ -10,11 +11,12 @@
 from dvuploader.file import File
 from dvuploader.packaging import distribute_files, zip_files
-from dvuploader.utils import build_url
+from dvuploader.utils import build_url, retrieve_dataset_files
 
 
 MAX_RETRIES = os.environ.get("DVUPLOADER_MAX_RETRIES", 15)
 NATIVE_UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add"
 NATIVE_REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"
+NATIVE_METADATA_ENDPOINT = "/api/files/{FILE_ID}/metadata"
 
 
 assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer"
@@ -74,6 +76,22 @@ async def native_upload(
             ]
 
             responses = await asyncio.gather(*tasks)
+            _validate_upload_responses(responses, files)
+
+            await _update_metadata(
+                session=session,
+                files=files,
+                persistent_id=persistent_id,
+                dataverse_url=dataverse_url,
+                api_token=api_token,
+            )
+
+
+def _validate_upload_responses(
+    responses: List[Tuple],
+    files: List[File],
+) -> None:
+    """Validates the responses of the native upload requests."""
 
     for (status, response), file in zip(responses, files):
         if status == 200:
@@ -174,20 +192,21 @@ async def _single_native_upload(
             endpoint=NATIVE_REPLACE_ENDPOINT.format(FILE_ID=file.file_id),
         )
 
-    json_data = {
-        "description": file.description,
-        "forceReplace": True,
-        "directoryLabel": file.directory_label,
-        "categories": file.categories,
-        "restrict": file.restrict,
-        "forceReplace": True,
-    }
+    json_data = _get_json_data(file)
 
     for _ in range(MAX_RETRIES):
         formdata = aiohttp.FormData()
 
-        formdata.add_field("jsonData", json.dumps(json_data), content_type="application/json")
-        formdata.add_field("file", file.handler, filename=file.file_name)
+        formdata.add_field(
+            "jsonData",
+            json.dumps(json_data),
+            content_type="application/json",
+        )
+        formdata.add_field(
+            "file",
+            file.handler,
+            filename=file.file_name,
+        )
 
         async with session.post(endpoint, data=formdata) as response:
             status = response.status
@@ -234,3 +253,127 @@ def file_sender(
         yield chunk
         chunk = file.handler.read(chunk_size)
         progress.advance(pbar, advance=chunk_size)
+
+
+def _get_json_data(file: File) -> Dict:
+    """Returns the JSON data for the native upload request."""
+    return {
+        "description": file.description,
+        "directoryLabel": file.directory_label,
+        "categories": file.categories,
+        "restrict": file.restrict,
+        "forceReplace": True,
+    }
+
+
+async def _update_metadata(
+    session: aiohttp.ClientSession,
+    files: List[File],
+    dataverse_url: str,
+    api_token: str,
+    persistent_id: str,
+):
+    """Updates the metadata of the given files in a Dataverse repository.
+
+    Args:
+
+        session (aiohttp.ClientSession): The aiohttp client session.
+        files (List[File]): The files to update the metadata for.
+        dataverse_url (str): The URL of the Dataverse repository.
+        api_token (str): The API token of the Dataverse repository.
+        persistent_id (str): The persistent identifier of the dataset.
+    """
+
+    file_mapping = _retrieve_file_ids(
+        persistent_id=persistent_id,
+        dataverse_url=dataverse_url,
+        api_token=api_token,
+    )
+
+    tasks = []
+
+    for file in files:
+        dv_path = os.path.join(file.directory_label, file.file_name)  # type: ignore
+
+        try:
+            file_id = file_mapping[dv_path]
+        except KeyError:
+            raise ValueError(
+                (
+                    f"File {dv_path} not found in Dataverse repository. "
+                    "This may be due to the file not being uploaded to the repository."
+                )
+            )
+
+        task = _update_single_metadata(
+            session=session,
+            url=NATIVE_METADATA_ENDPOINT.format(FILE_ID=file_id),
+            file=file,
+        )
+
+        tasks.append(task)
+
+    await asyncio.gather(*tasks)
+
+
+async def _update_single_metadata(
+    session: aiohttp.ClientSession,
+    url: str,
+    file: File,
+) -> None:
+    """Updates the metadata of a single file in a Dataverse repository."""
+
+    json_data = _get_json_data(file)
+
+    del json_data["forceReplace"]
+    del json_data["restrict"]
+
+    formdata = aiohttp.FormData()
+    formdata.add_field(
+        "jsonData",
+        json.dumps(json_data),
+        content_type="application/json",
+    )
+
+    async with session.post(url, data=formdata) as response:
+        response.raise_for_status()
+
+
+def _retrieve_file_ids(
+    persistent_id: str,
+    dataverse_url: str,
+    api_token: str,
+) -> Dict[str, str]:
+    """Retrieves the file IDs of the given files.
+
+    Args:
+        files (List[File]): The files to retrieve the IDs for.
+        persistent_id (str): The persistent identifier of the dataset.
+        dataverse_url (str): The URL of the Dataverse repository.
+        api_token (str): The API token of the Dataverse repository.
+
+    Returns:
+        Dict[str, str]: The mapping from file path to file ID.
+ """ + + # Fetch file metadata + ds_files = retrieve_dataset_files( + persistent_id=persistent_id, + dataverse_url=dataverse_url, + api_token=api_token, + ) + + return _create_file_id_path_mapping(ds_files) + + +def _create_file_id_path_mapping(files): + """Creates dictionary that maps from directoryLabel + filename to ID""" + mapping = {} + + for file in files: + directory_label = file.get("directoryLabel", "") + file = file["dataFile"] + path = os.path.join(directory_label, file["filename"]) + mapping[path] = file["id"] + + return mapping diff --git a/tests/integration/test_native_upload.py b/tests/integration/test_native_upload.py index 183cc54..ddeeb7d 100644 --- a/tests/integration/test_native_upload.py +++ b/tests/integration/test_native_upload.py @@ -106,7 +106,6 @@ def test_forced_native_upload( assert len(files) == 3 assert sorted([file["label"] for file in files]) == sorted(expected_files) - def test_native_upload_by_handler( self, credentials, @@ -116,8 +115,16 @@ def test_native_upload_by_handler( # Arrange byte_string = b"Hello, World!" files = [ - File(filepath="subdir/file.txt", handler=BytesIO(byte_string)), - File(filepath="biggerfile.txt", handler=BytesIO(byte_string*10000)), + File( + filepath="subdir/file.txt", + handler=BytesIO(byte_string), + description="This is a test", + ), + File( + filepath="biggerfile.txt", + handler=BytesIO(byte_string * 10000), + description="This is a test", + ), ] # Create Dataset @@ -154,5 +161,14 @@ def test_native_upload_by_handler( file = next(file for file in files if file["label"] == ex_f) - assert file["label"] == ex_f, f"File label does not match for file {json.dumps(file)}" - assert file.get("directoryLabel", "") == ex_dir, f"Directory label does not match for file {json.dumps(file)}" + assert ( + file["label"] == ex_f + ), f"File label does not match for file {json.dumps(file)}" + + assert ( + file.get("directoryLabel", "") == ex_dir + ), f"Directory label does not match for file {json.dumps(file)}" + + assert ( + file["description"] == "This is a test" + ), f"Description does not match for file {json.dumps(file)}"