diff --git a/dvuploader/directupload.py b/dvuploader/directupload.py
index c270e0d..83a6d8f 100644
--- a/dvuploader/directupload.py
+++ b/dvuploader/directupload.py
@@ -8,15 +8,22 @@
 import aiohttp
 
 from dvuploader.file import File
-from dvuploader.nativeupload import file_sender
 from dvuploader.utils import build_url
 
 TESTING = bool(os.environ.get("DVUPLOADER_TESTING", False))
+MAX_FILE_DISPLAY = int(os.environ.get("DVUPLOADER_MAX_FILE_DISPLAY", 50))
+MAX_RETRIES = int(os.environ.get("DVUPLOADER_MAX_RETRIES", 10))
+
+assert isinstance(
+    MAX_FILE_DISPLAY, int
+), "DVUPLOADER_MAX_FILE_DISPLAY must be an integer"
+
+assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer"
 
 TICKET_ENDPOINT = "/api/datasets/:persistentId/uploadurls"
 ADD_FILE_ENDPOINT = "/api/datasets/:persistentId/addFiles"
-UPLOAD_ENDPOINT = "/api/datasets/:persistentId/add?persistentId="
-REPLACE_ENDPOINT = "/api/files/{FILE_ID}/replace"
+UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId="
+REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId="
 
 
 async def direct_upload(
@@ -44,6 +51,7 @@
         None
     """
 
+    leave_bar = len(files) < MAX_FILE_DISPLAY
     connector = aiohttp.TCPConnector(limit=n_parallel_uploads)
     async with aiohttp.ClientSession(connector=connector) as session:
         tasks = [
@@ -56,6 +64,7 @@
                 pbar=pbar,
                 progress=progress,
                 delay=0.0,
+                leave_bar=leave_bar,
             )
             for pbar, file in zip(pbars, files)
         ]
@@ -73,28 +82,20 @@
         "x-amz-tagging": "dv-state=temp",
     }
 
-    connector = aiohttp.TCPConnector(limit=4)
-    pbar = progress.add_task("╰── [bold white]Registering files", total=len(files))
-    results = []
+    pbar = progress.add_task("╰── [bold white]Registering files", total=1)
+    connector = aiohttp.TCPConnector(limit=2)
 
     async with aiohttp.ClientSession(
        headers=headers,
        connector=connector,
    ) as session:
-        for file in files:
-            results.append(
-                await _add_file_to_ds(
-                    session=session,
-                    file=file,
-                    dataverse_url=dataverse_url,
-                    pid=persistent_id,
-                )
-            )
-
-            progress.update(pbar, advance=1)
-
-        for file, status in zip(files, results):
-            if status is False:
-                print(f"❌ Failed to register file '{file.fileName}' at Dataverse")
+        await _add_files_to_ds(
+            session=session,
+            files=files,
+            dataverse_url=dataverse_url,
+            pid=persistent_id,
+            progress=progress,
+            pbar=pbar,
+        )
 
 async def _upload_to_store(
@@ -106,6 +107,7 @@
     pbar,
     progress,
     delay: float,
+    leave_bar: bool,
 ):
     """
     Uploads a file to a Dataverse collection using direct upload.
@@ -119,6 +121,7 @@
         pbar: The progress bar object.
         progress: The progress object.
         delay (float): The delay in seconds before starting the upload.
+        leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete.
 
     Returns:
         tuple: A tuple containing the upload status (bool) and the file object.
@@ -146,6 +149,7 @@
             pbar=pbar,
             progress=progress,
             api_token=api_token,
+            leave_bar=leave_bar,
         )
 
     else:
@@ -207,6 +211,7 @@
     pbar,
     progress,
     api_token: str,
+    leave_bar: bool,
 ) -> Tuple[bool, str]:
     """
     Uploads a single part of a file to a remote server using HTTP PUT method.
@@ -217,6 +222,7 @@
         filepath (str): The path to the file to be uploaded.
         pbar (tqdm): A progress bar object to track the upload progress.
         progress: The progress object used to update the progress bar.
+        leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete.
 
     Returns:
         Tuple[bool, str]: A tuple containing the status of the upload (True for success, False for failure)
@@ -235,11 +241,7 @@
         params = {
             "headers": headers,
             "url": ticket["url"],
-            "data": file_sender(
-                file_name=filepath,
-                progress=progress,
-                pbar=pbar,
-            ),
+            "data": open(filepath, "rb"),
         }
 
         async with session.put(**params) as response:
@@ -247,7 +249,17 @@
             response.raise_for_status()
 
             if status:
-                progress.update(pbar, advance=os.path.getsize(filepath))
+                progress.update(
+                    pbar,
+                    advance=os.path.getsize(filepath),
+                )
+
+                await asyncio.sleep(0.1)
+
+                progress.update(
+                    pbar,
+                    visible=leave_bar,
+                )
 
             return status, storage_identifier
 
@@ -463,12 +475,14 @@ async def _abort_upload(
         response.raise_for_status()
 
 
-async def _add_file_to_ds(
+async def _add_files_to_ds(
     session: aiohttp.ClientSession,
     dataverse_url: str,
     pid: str,
-    file: File,
-) -> bool:
+    files: List[File],
+    progress,
+    pbar,
+) -> None:
     """
     Adds a file to a Dataverse dataset.
 
@@ -481,26 +495,77 @@
     Returns:
         bool: True if the file was added successfully, False otherwise.
     """
-    if not file.to_replace:
-        url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid)
-    else:
-        url = build_url(
-            dataverse_url=dataverse_url,
-            endpoint=urljoin(
-                dataverse_url,
-                REPLACE_ENDPOINT.format(FILE_ID=file.file_id),
-            ),
-        )
-
-    json_data = file.model_dump_json(
-        by_alias=True,
-        exclude={"to_replace", "file_id"},
+    novel_url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid)
+    replace_url = urljoin(dataverse_url, REPLACE_ENDPOINT + pid)
+
+    novel_json_data = _prepare_registration(files, use_replace=False)
+    replace_json_data = _prepare_registration(files, use_replace=True)
+
+    await _multipart_json_data_request(
+        session=session,
+        json_data=novel_json_data,
+        url=novel_url,
+    )
+
+    await _multipart_json_data_request(
+        session=session,
+        json_data=replace_json_data,
+        url=replace_url,
+    )
+
+    progress.update(pbar, advance=1)
+
+
+def _prepare_registration(files: List[File], use_replace: bool) -> str:
+    """
+    Prepares the files for registration at the Dataverse instance.
+
+    Args:
+        files (List[File]): The list of files to prepare.
+        use_replace (bool): Whether to prepare the replacement entries (True) or the new-file entries (False).
+
+    Returns:
+        str: A JSON string containing the file data.
+    """
+
+    exclude = {"to_replace"} if use_replace else {"to_replace", "file_id"}
+
+    return json.dumps(
+        [
+            file.model_dump(
+                by_alias=True,
+                exclude=exclude,
+                exclude_none=True,
+            )
+            for file in files
+            if file.to_replace is use_replace
+        ],
         indent=2,
     )
 
+
+async def _multipart_json_data_request(
+    json_data: str,
+    url: str,
+    session: aiohttp.ClientSession,
+):
+    """
+    Sends a multipart/form-data POST request with JSON data to the specified URL using the provided session.
+
+    Args:
+        json_data (str): The JSON data to be sent in the request body.
+        url (str): The URL to send the request to.
+        session (aiohttp.ClientSession): The aiohttp client session to use for the request.
+
+    Raises:
+        aiohttp.ClientResponseError: If the response status code is not successful.
+
+    Returns:
+        None
+    """
     with aiohttp.MultipartWriter("form-data") as writer:
         json_part = writer.append(json_data)
         json_part.set_content_disposition("form-data", name="jsonData")
 
         async with session.post(url, data=writer) as response:
-            return response.status == 200
+            response.raise_for_status()
diff --git a/dvuploader/dvuploader.py b/dvuploader/dvuploader.py
index ad1544a..b032a9d 100644
--- a/dvuploader/dvuploader.py
+++ b/dvuploader/dvuploader.py
@@ -9,6 +9,7 @@
 from rich.progress import Progress, TaskID
 from rich.table import Table
 from rich.console import Console
+from rich.panel import Panel
 
 from dvuploader.directupload import (
     TICKET_ENDPOINT,
@@ -40,6 +41,7 @@ def upload(
         dataverse_url: str,
         api_token: str,
         n_parallel_uploads: int = 1,
+        force_native: bool = False,
     ) -> None:
         """
         Uploads the files to the specified Dataverse repository in parallel.
@@ -53,7 +55,24 @@
         Returns:
             None
         """
-        # Validate and hash files
+
+        print("\n")
+        info = "\n".join(
+            [
+                f"Server: [bold]{dataverse_url}[/bold]",  # type: ignore
+                f"PID: [bold]{persistent_id}[/bold]",  # type: ignore
+                f"Files: {len(self.files)}",
+            ]
+        )
+
+        panel = Panel(
+            info,
+            title="[bold]DVUploader[/bold]",
+            expand=False,
+        )
+
+        rich.print(panel)
+
         asyncio.run(self._validate_and_hash_files())
 
         # Check for duplicates
@@ -81,7 +100,7 @@
             persistent_id=persistent_id,
         )
 
-        if not has_direct_upload:
+        if not has_direct_upload and not force_native:
             rich.print(
                 "\n[bold italic white]⚠️ Direct upload not supported. Falling back to Native API."
             )
@@ -90,7 +109,7 @@
 
         progress, pbars = self.setup_progress_bars(files=files)
 
-        if not has_direct_upload:
+        if not has_direct_upload or force_native:
             with progress:
                 asyncio.run(
                     native_upload(
@@ -127,15 +146,35 @@
             None
         """
 
-        rich.print("\n[italic white]📝 Preparing upload\n")
+        print("\n")
 
-        tasks = [self._validate_and_hash_file(file=file) for file in self.files]
+        progress = Progress()
+        task = progress.add_task(
+            "[bold italic white]📦 Preparing upload[/bold italic white]",
+            total=len(self.files),
+        )
+
+        with progress:
+            tasks = [
+                self._validate_and_hash_file(
+                    file=file,
+                    progress=progress,
+                    task_id=task,
+                )
+                for file in self.files
+            ]
+
+            await asyncio.gather(*tasks)
 
-        await asyncio.gather(*tasks)
+        print("\n")
 
     @staticmethod
-    async def _validate_and_hash_file(file: File):
+    async def _validate_and_hash_file(
+        file: File,
+        progress: Progress,
+        task_id: TaskID,
+    ):
         file.extract_filename_hash_file()
+        progress.update(task_id, advance=1)
 
     def _check_duplicates(
         self,
@@ -169,6 +208,9 @@
         table.add_column("Action")
 
         to_remove = []
+        over_threshold = len(self.files) > 50
+        n_new_files = 0
+        n_skip_files = 0
 
         for file in self.files:
             has_same_hash = any(
@@ -176,11 +218,13 @@
             )
 
             if has_same_hash and file.checksum:
+                n_skip_files += 1
                 table.add_row(
                     file.fileName, "[bright_black]Same hash", "[bright_black]Skip"
                 )
                 to_remove.append(file)
             else:
+                n_new_files += 1
                 table.add_row(
                     file.fileName, "[spring_green3]New", "[spring_green3]Upload"
                 )
@@ -193,6 +237,14 @@
             self.files.remove(file)
 
         console = Console()
+
+        if over_threshold:
+            table = Table(title="[bold white]🔎 Checking dataset files")
+
+            table.add_column("New", style="spring_green3", no_wrap=True)
+            table.add_column("Skipped", style="bright_black", no_wrap=True)
+            table.add_row(str(n_new_files), str(n_skip_files))
+
         console.print(table)
 
     @staticmethod
diff --git a/dvuploader/file.py b/dvuploader/file.py
index f98216c..27f9ac2 100644
--- a/dvuploader/file.py
+++ b/dvuploader/file.py
@@ -1,7 +1,7 @@
 import os
-from typing import List, Optional
+from typing import List, Optional, Union
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 import rich
 
 from dvuploader.checksum import Checksum, ChecksumTypes
@@ -31,6 +31,8 @@
     """
 
+    model_config: ConfigDict = ConfigDict(populate_by_name=True)
+
     filepath: str = Field(..., exclude=True)
     description: str = ""
     directoryLabel: str = ""
@@ -42,7 +44,7 @@
     fileName: Optional[str] = None
     checksum: Optional[Checksum] = None
     to_replace: bool = False
-    file_id: Optional[str] = None
+    file_id: Optional[Union[str, int]] = Field(default=None, alias="fileToReplaceId")
 
     def extract_filename_hash_file(self):
         """
diff --git a/dvuploader/nativeupload.py b/dvuploader/nativeupload.py
index b8e1da9..c24187f 100644
--- a/dvuploader/nativeupload.py
+++ b/dvuploader/nativeupload.py
@@ -49,6 +49,7 @@
         "headers": {"X-Dataverse-key": api_token},
         "connector": aiohttp.TCPConnector(
             limit=n_parallel_uploads,
+            timeout_ceil_threshold=120,
         ),
     }
 
diff --git a/dvuploader/packaging.py b/dvuploader/packaging.py
index cc55289..2a45f6f 100644
--- a/dvuploader/packaging.py
+++ b/dvuploader/packaging.py
@@ -6,7 +6,7 @@
 MAXIMUM_PACKAGE_SIZE = int(
     os.environ.get(
         "DVUPLOADER_MAX_PKG_SIZE",
-        1024**3,  # 1 GB
+        2 * 1024**3,  # 2 GB
     )
 )
 
diff --git a/poetry.lock b/poetry.lock
index e7fa9df..21e40e9 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1315,6 +1315,7 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
diff --git a/tests/integration/test_native_upload.py b/tests/integration/test_native_upload.py
index 31b7a49..923936e 100644
--- a/tests/integration/test_native_upload.py
+++ b/tests/integration/test_native_upload.py
@@ -53,3 +53,50 @@
 
         assert len(files) == 3
         assert sorted([file["label"] for file in files]) == sorted(expected_files)
+
+    def test_forced_native_upload(
+        self,
+        credentials,
+    ):
+        BASE_URL, API_TOKEN = credentials
+
+        with tempfile.TemporaryDirectory() as directory:
+            # Arrange
+            create_mock_file(directory, "small_file.txt", size=1)
+            create_mock_file(directory, "mid_file.txt", size=50)
+            create_mock_file(directory, "large_file.txt", size=200)
+
+            # Add all files in the directory
+            files = add_directory(directory=directory)
+
+            # Create Dataset
+            pid = create_dataset(
+                parent="Root",
+                server_url=BASE_URL,
+                api_token=API_TOKEN,
+            )
+
+            # Act
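+            # (force_native=True routes this upload through the Native API,
+            # even though the target installation may support direct upload)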
+            uploader = DVUploader(files=files)
+            uploader.upload(
+                persistent_id=pid,
+                api_token=API_TOKEN,
+                dataverse_url=BASE_URL,
+                n_parallel_uploads=1,
+                force_native=True,
+            )
+
+            # Assert
+            expected_files = [
+                "small_file.txt",
+                "mid_file.txt",
+                "large_file.txt",
+            ]
+            files = retrieve_dataset_files(
+                dataverse_url=BASE_URL,
+                persistent_id=pid,
+                api_token=API_TOKEN,
+            )
+
+            assert len(files) == 3
+            assert sorted([file["label"] for file in files]) == sorted(expected_files)
diff --git a/tests/unit/test_directupload.py b/tests/unit/test_directupload.py
index 658f290..415dbf2 100644
--- a/tests/unit/test_directupload.py
+++ b/tests/unit/test_directupload.py
@@ -1,10 +1,11 @@
 from urllib.parse import urljoin
 import aiohttp
 import pytest
+from rich.progress import Progress
 
 from dvuploader.directupload import (
-    _add_file_to_ds,
     UPLOAD_ENDPOINT,
     REPLACE_ENDPOINT,
+    _add_files_to_ds,
     _validate_ticket_response,
 )
@@ -24,16 +25,24 @@
         dataverse_url = "https://example.com"
         pid = "persistent_id"
         fpath = "tests/fixtures/add_dir_files/somefile.txt"
-        file = File(filepath=fpath)
+        files = [File(filepath=fpath)]
+        progress = Progress()
+        pbar = progress.add_task("Uploading", total=1)
 
         # Invoke the function
-        result = await _add_file_to_ds(session, dataverse_url, pid, file)
+        await _add_files_to_ds(
+            session=session,
+            dataverse_url=dataverse_url,
+            pid=pid,
+            files=files,
+            progress=progress,
+            pbar=pbar,
+        )
 
         # Assert that the response status is 200 and the result is True
         assert mock_post.called_with(
             urljoin(dataverse_url, UPLOAD_ENDPOINT + pid), data=mocker.ANY
         )
-        assert result is True
 
     @pytest.mark.asyncio
     async def test_successfully_replace_file_with_valid_filepath(self, mocker):
@@ -46,17 +55,25 @@
         dataverse_url = "https://example.com"
         pid = "persistent_id"
         fpath = "tests/fixtures/add_dir_files/somefile.txt"
-        file = File(filepath=fpath, file_id="0")
+        files = [File(filepath=fpath, file_id="0")]
+        progress = Progress()
+        pbar = progress.add_task("Uploading", total=1)
 
         # Invoke the function
-        result = await _add_file_to_ds(session, dataverse_url, pid, file)
+        await _add_files_to_ds(
+            session=session,
+            dataverse_url=dataverse_url,
+            pid=pid,
+            files=files,
+            progress=progress,
+            pbar=pbar,
+        )
 
         # Assert that the response status is 200 and the result is True
         assert mock_post.called_with(
-            urljoin(dataverse_url, REPLACE_ENDPOINT.format(FILE_ID=file.file_id)),
+            urljoin(dataverse_url, REPLACE_ENDPOINT + pid),
             data=mocker.ANY,
         )
-        assert result is True
 
 
 class Test_ValidateTicketResponse:
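
For reference, a minimal usage sketch of the new force_native flag, mirroring test_forced_native_upload above. The PID, server URL, and API token below are placeholders, and DVUploader/add_directory are assumed to be importable from the package root, as in the tests:

    from dvuploader import DVUploader, add_directory

    # Collect every file beneath a local directory (as the integration test does)
    files = add_directory(directory="./data")

    # Force the Native API path even if the installation supports direct upload
    uploader = DVUploader(files=files)
    uploader.upload(
        persistent_id="doi:10.70122/FK2/XXXXXX",  # placeholder PID
        dataverse_url="https://demo.dataverse.org",  # placeholder server
        api_token="0000-0000-0000-0000",  # placeholder token
        n_parallel_uploads=1,
        force_native=True,
    )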