22 changes: 10 additions & 12 deletions dvuploader/dvuploader.py
@@ -1,4 +1,5 @@
import asyncio
import json
from urllib.parse import urljoin
import requests
import os
@@ -19,6 +20,7 @@
from dvuploader.nativeupload import native_upload
from dvuploader.utils import build_url, retrieve_dataset_files, setup_pbar


class DVUploader(BaseModel):
"""
A class for uploading files to a Dataverse repository.
@@ -153,10 +155,7 @@ async def _validate_and_hash_files(self, verbose: bool):

if not verbose:
tasks = [
self._validate_and_hash_file(
file=file,
verbose=self.verbose
)
self._validate_and_hash_file(file=file, verbose=self.verbose)
for file in self.files
]

@@ -175,10 +174,7 @@ async def _validate_and_hash_files(self, verbose: bool):

tasks = [
self._validate_and_hash_file(
file=file,
progress=progress,
task_id=task,
verbose=self.verbose
file=file, progress=progress, task_id=task, verbose=self.verbose
)
for file in self.files
]
@@ -197,7 +193,7 @@ async def _validate_and_hash_file(
file.extract_file_name_hash_file()

if verbose:
progress.update(task_id, advance=1) # type: ignore
progress.update(task_id, advance=1) # type: ignore

def _check_duplicates(
self,
@@ -240,7 +236,7 @@ def _check_duplicates(
map(lambda dsFile: self._check_hashes(file, dsFile), ds_files)
)

if has_same_hash and file.checksum:
if has_same_hash:
n_skip_files += 1
table.add_row(
file.file_name, "[bright_black]Same hash", "[bright_black]Skip"
@@ -316,12 +312,14 @@ def _check_hashes(file: File, dsFile: Dict):
return False

hash_algo, hash_value = tuple(dsFile["dataFile"]["checksum"].values())
path = os.path.join(
dsFile.get("directoryLabel", ""), dsFile["dataFile"]["filename"]
)

return (
file.checksum.value == hash_value
and file.checksum.type == hash_algo
and file.file_name == dsFile["label"]
and file.directory_label == dsFile.get("directoryLabel", "")
and path == os.path.join(file.directory_label, file.file_name) # type: ignore
)

@staticmethod
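Note on the `_check_hashes` change above: the separate label/directoryLabel comparisons are replaced by a single path comparison built from `directoryLabel` and `filename`, so a remote file counts as a duplicate when both its checksum and its full path match the local file. Below is a minimal, illustrative sketch of that predicate — not the library code; the `local` dict is a stand-in for dvuploader's `File` model:

```python
import os
from typing import Dict


def is_duplicate(local: Dict, ds_file: Dict) -> bool:
    """True when a dataset file matches a local file by checksum and path.

    `local` mimics File fields (checksum type/value, directory_label, file_name);
    `ds_file` mimics one entry of the native-API dataset file listing.
    """
    data_file = ds_file["dataFile"]
    hash_algo, hash_value = tuple(data_file["checksum"].values())

    # Remote path: directoryLabel (may be absent) + filename, as in _check_hashes
    remote_path = os.path.join(ds_file.get("directoryLabel", ""), data_file["filename"])
    local_path = os.path.join(local["directory_label"], local["file_name"])

    return (
        local["checksum_value"] == hash_value
        and local["checksum_type"] == hash_algo
        and remote_path == local_path
    )


# Example: same hash and same subdir/file.txt path -> duplicate, so it would be skipped
print(
    is_duplicate(
        {
            "checksum_type": "MD5",
            "checksum_value": "abc123",
            "directory_label": "subdir",
            "file_name": "file.txt",
        },
        {
            "directoryLabel": "subdir",
            "dataFile": {"checksum": {"type": "MD5", "value": "abc123"}, "filename": "file.txt"},
        },
    )
)  # True
```
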
165 changes: 154 additions & 11 deletions dvuploader/nativeupload.py
@@ -3,18 +3,20 @@
import os
import tempfile
from typing import List, Tuple
from typing_extensions import Dict
import aiofiles
import aiohttp

from rich.progress import Progress, TaskID

from dvuploader.file import File
from dvuploader.packaging import distribute_files, zip_files
from dvuploader.utils import build_url
from dvuploader.utils import build_url, retrieve_dataset_files

MAX_RETRIES = os.environ.get("DVUPLOADER_MAX_RETRIES", 15)
NATIVE_UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add"
NATIVE_REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"
NATIVE_METADATA_ENDPOINT = "/api/files/{FILE_ID}/metadata"

assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer"

@@ -74,6 +76,22 @@ async def native_upload(
]

responses = await asyncio.gather(*tasks)
_validate_upload_responses(responses, files)

await _update_metadata(
session=session,
files=files,
persistent_id=persistent_id,
dataverse_url=dataverse_url,
api_token=api_token,
)


def _validate_upload_responses(
responses: List[Tuple],
files: List[File],
) -> None:
"""Validates the responses of the native upload requests."""

for (status, response), file in zip(responses, files):
if status == 200:
@@ -174,20 +192,21 @@ async def _single_native_upload(
endpoint=NATIVE_REPLACE_ENDPOINT.format(FILE_ID=file.file_id),
)

json_data = {
"description": file.description,
"forceReplace": True,
"directoryLabel": file.directory_label,
"categories": file.categories,
"restrict": file.restrict,
"forceReplace": True,
}
json_data = _get_json_data(file)

for _ in range(MAX_RETRIES):

formdata = aiohttp.FormData()
formdata.add_field("jsonData", json.dumps(json_data), content_type="application/json")
formdata.add_field("file", file.handler, filename=file.file_name)
formdata.add_field(
"jsonData",
json.dumps(json_data),
content_type="application/json",
)
formdata.add_field(
"file",
file.handler,
filename=file.file_name,
)

async with session.post(endpoint, data=formdata) as response:
status = response.status
@@ -234,3 +253,127 @@ def file_sender(
yield chunk
chunk = file.handler.read(chunk_size)
progress.advance(pbar, advance=chunk_size)


def _get_json_data(file: File) -> Dict:
"""Returns the JSON data for the native upload request."""
return {
"description": file.description,
"directoryLabel": file.directory_label,
"categories": file.categories,
"restrict": file.restrict,
"forceReplace": True,
}


async def _update_metadata(
session: aiohttp.ClientSession,
files: List[File],
dataverse_url: str,
api_token: str,
persistent_id: str,
):
"""Updates the metadata of the given files in a Dataverse repository.

Args:

session (aiohttp.ClientSession): The aiohttp client session.
files (List[File]): The files to update the metadata for.
dataverse_url (str): The URL of the Dataverse repository.
api_token (str): The API token of the Dataverse repository.
persistent_id (str): The persistent identifier of the dataset.
"""

file_mapping = _retrieve_file_ids(
persistent_id=persistent_id,
dataverse_url=dataverse_url,
api_token=api_token,
)

tasks = []

for file in files:
dv_path = os.path.join(file.directory_label, file.file_name) # type: ignore

try:
file_id = file_mapping[dv_path]
except KeyError:
raise ValueError(
(
f"File {dv_path} not found in Dataverse repository.",
"This may be due to the file not being uploaded to the repository.",
)
)

task = _update_single_metadata(
session=session,
url=NATIVE_METADATA_ENDPOINT.format(FILE_ID=file_id),
file=file,
)

tasks.append(task)

await asyncio.gather(*tasks)


async def _update_single_metadata(
session: aiohttp.ClientSession,
url: str,
file: File,
) -> None:
"""Updates the metadata of a single file in a Dataverse repository."""

json_data = _get_json_data(file)

del json_data["forceReplace"]
del json_data["restrict"]

formdata = aiohttp.FormData()
formdata.add_field(
"jsonData",
json.dumps(json_data),
content_type="application/json",
)

async with session.post(url, data=formdata) as response:
response.raise_for_status()


def _retrieve_file_ids(
persistent_id: str,
dataverse_url: str,
api_token: str,
) -> Dict[str, str]:
"""Retrieves the file IDs of the given files.

Args:
files (List[File]): The files to retrieve the IDs for.
persistent_id (str): The persistent identifier of the dataset.
dataverse_url (str): The URL of the Dataverse repository.
api_token (str): The API token of the Dataverse repository.

Returns:
Dict[str, str]: The list of file IDs.
"""

# Fetch file metadata
ds_files = retrieve_dataset_files(
persistent_id=persistent_id,
dataverse_url=dataverse_url,
api_token=api_token,
)

return _create_file_id_path_mapping(ds_files)


def _create_file_id_path_mapping(files):
"""Creates dictionary that maps from directoryLabel + filename to ID"""
mapping = {}

for file in files:
directory_label = file.get("directoryLabel", "")
file = file["dataFile"]
path = os.path.join(directory_label, file["filename"])
mapping[path] = file["id"]

return mapping
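
The new metadata pass in `nativeupload.py` proceeds in three steps: after the upload responses are validated, it fetches the dataset's file listing, builds a `directoryLabel/filename → file id` mapping, and POSTs each file's `jsonData` (with `forceReplace` and `restrict` removed) to `/api/files/{FILE_ID}/metadata`. Here is a small, self-contained sketch of the mapping step, using a hand-written listing in the shape the native API returns (field names as in the diff; the sample ids and filenames are made up):

```python
import os
from typing import Dict, List


def build_path_id_mapping(ds_files: List[Dict]) -> Dict[str, int]:
    """Map 'directoryLabel/filename' to the Dataverse file id, as _create_file_id_path_mapping does."""
    mapping = {}
    for entry in ds_files:
        directory_label = entry.get("directoryLabel", "")
        data_file = entry["dataFile"]
        mapping[os.path.join(directory_label, data_file["filename"])] = data_file["id"]
    return mapping


# Hypothetical listing mimicking retrieve_dataset_files() output
listing = [
    {"directoryLabel": "subdir", "dataFile": {"filename": "file.txt", "id": 101}},
    {"dataFile": {"filename": "biggerfile.txt", "id": 102}},
]

print(build_path_id_mapping(listing))
# {'subdir/file.txt': 101, 'biggerfile.txt': 102}
```

Keying the mapping on the joined path rather than the bare filename lets two files with the same name in different directories resolve to different file ids.
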
26 changes: 21 additions & 5 deletions tests/integration/test_native_upload.py
@@ -106,7 +106,6 @@ def test_forced_native_upload(
assert len(files) == 3
assert sorted([file["label"] for file in files]) == sorted(expected_files)


def test_native_upload_by_handler(
self,
credentials,
@@ -116,8 +115,16 @@ def test_native_upload_by_handler(
# Arrange
byte_string = b"Hello, World!"
files = [
File(filepath="subdir/file.txt", handler=BytesIO(byte_string)),
File(filepath="biggerfile.txt", handler=BytesIO(byte_string*10000)),
File(
filepath="subdir/file.txt",
handler=BytesIO(byte_string),
description="This is a test",
),
File(
filepath="biggerfile.txt",
handler=BytesIO(byte_string * 10000),
description="This is a test",
),
]

# Create Dataset
@@ -154,5 +161,14 @@

file = next(file for file in files if file["label"] == ex_f)

assert file["label"] == ex_f, f"File label does not match for file {json.dumps(file)}"
assert file.get("directoryLabel", "") == ex_dir, f"Directory label does not match for file {json.dumps(file)}"
assert (
file["label"] == ex_f
), f"File label does not match for file {json.dumps(file)}"

assert (
file.get("directoryLabel", "") == ex_dir
), f"Directory label does not match for file {json.dumps(file)}"

assert (
file["description"] == "This is a test"
), f"Description does not match for file {json.dumps(file)}"