diff --git a/dvuploader/checksum.py b/dvuploader/checksum.py index d6a11dd..b5a3a03 100644 --- a/dvuploader/checksum.py +++ b/dvuploader/checksum.py @@ -1,12 +1,12 @@ import hashlib from enum import Enum from typing import IO, Callable +from pydantic.fields import PrivateAttr +from typing_extensions import Optional from pydantic import BaseModel, ConfigDict, Field - - class ChecksumTypes(Enum): """Enum class representing different types of checksums. @@ -24,11 +24,15 @@ class ChecksumTypes(Enum): class Checksum(BaseModel): - """Checksum class represents a checksum object with type and value fields. + """Class for calculating and storing file checksums. + + This class handles checksum calculation and storage for files being uploaded to Dataverse. + It supports multiple hash algorithms through the ChecksumTypes enum. Attributes: - type (str): The type of the checksum. - value (str): The value of the checksum. + type (str): The type of checksum algorithm being used (e.g. "SHA-1", "MD5") + value (Optional[str]): The calculated checksum value, or None if not yet calculated + _hash_fun (PrivateAttr): Internal hash function instance used for calculation """ model_config = ConfigDict( @@ -37,44 +41,58 @@ class Checksum(BaseModel): ) type: str = Field(..., alias="@type") - value: str = Field(..., alias="@value") + value: Optional[str] = Field(None, alias="@value") + _hash_fun = PrivateAttr(default=None) @classmethod - def from_file( + def from_algo( cls, - handler: IO, hash_fun: Callable, hash_algo: str, ) -> "Checksum": - """Takes a file path and returns a checksum object. + """Creates a new Checksum instance configured for a specific hash algorithm. Args: - handler (IO): The file handler to generate the checksum for. - hash_fun (Callable): The hash function to use for generating the checksum. - hash_algo (str): The hash algorithm to use for generating the checksum. + hash_fun (Callable): Hash function constructor (e.g. hashlib.sha1) + hash_algo (str): Name of the hash algorithm (e.g. "SHA-1") Returns: - Checksum: A Checksum object with type and value fields. + Checksum: A new Checksum instance ready for calculating checksums + """ + + cls = cls(type=hash_algo, value=None) # type: ignore + cls._hash_fun = hash_fun() + + return cls + + def apply_checksum(self): + """Finalizes and stores the calculated checksum value. + + This should be called after all data has been processed through the hash function. + The resulting checksum is stored in the value attribute. + + Raises: + AssertionError: If the hash function has not been initialized """ - value = cls._chunk_checksum(handler=handler, hash_fun=hash_fun) - return cls(type=hash_algo, value=value) # type: ignore + assert self._hash_fun is not None, "Checksum hash function is not set." + + self.value = self._hash_fun.hexdigest() @staticmethod - def _chunk_checksum( - handler: IO, - hash_fun: Callable, - blocksize=2**20 - ) -> str: - """Chunks a file and returns a checksum. + def _chunk_checksum(handler: IO, hash_fun: Callable, blocksize=2**20) -> str: + """Calculates a file's checksum by processing it in chunks. Args: - fpath (str): The file path to generate the checksum for. - hash_fun (Callable): The hash function to use for generating the checksum. - blocksize (int): The block size to use for reading the file. + handler (IO): File-like object to read data from + hash_fun (Callable): Hash function constructor to use + blocksize (int, optional): Size of chunks to read. 
Defaults to 1MB (2**20) Returns: - str: A string representing the checksum of the file. + str: Hexadecimal string representation of the calculated checksum + + Note: + This method resets the file position to the start after reading """ m = hash_fun() while True: diff --git a/dvuploader/cli.py b/dvuploader/cli.py index 1086011..ca74bd1 100644 --- a/dvuploader/cli.py +++ b/dvuploader/cli.py @@ -7,6 +7,17 @@ class CliInput(BaseModel): + """ + Model for CLI input parameters. + + Attributes: + api_token (str): API token for authentication with Dataverse + dataverse_url (str): URL of the Dataverse instance + persistent_id (str): Persistent identifier of the dataset + files (List[File]): List of files to upload + n_jobs (int): Number of parallel upload jobs to run (default: 1) + """ + api_token: str dataverse_url: str persistent_id: str @@ -19,19 +30,17 @@ class CliInput(BaseModel): def _parse_yaml_config(path: str) -> CliInput: """ - Parses a configuration file and returns a Class instance - containing a list of File objects, a persistent ID, a Dataverse URL, - and an API token. + Parse a YAML/JSON configuration file into a CliInput object. Args: - path (str): Path to a JSON/YAML file containing specifications for the files to upload. + path (str): Path to a YAML/JSON configuration file containing upload specifications Returns: - CliInput: Class instance containing a list of File objects, a persistent ID, - a Dataverse URL, and an API token. + CliInput: Object containing upload configuration parameters Raises: - ValueError: If the configuration file is invalid. + yaml.YAMLError: If the YAML/JSON file is malformed + ValidationError: If the configuration data does not match the CliInput model """ return CliInput(**yaml.safe_load(open(path))) # type: ignore @@ -44,18 +53,20 @@ def _validate_inputs( config_path: Optional[str], ) -> None: """ - Validates the inputs for the dvuploader command. + Validate CLI input parameters. + + Checks for valid combinations of configuration file and command line parameters. Args: - filepaths (List[str]): List of filepaths to be uploaded. - pid (str): Persistent identifier of the dataset. - dataverse_url (str): URL of the Dataverse instance. - api_token (str): API token for authentication. - config_path (Optional[str]): Path to the configuration file. + filepaths (List[str]): List of files to upload + pid (str): Persistent identifier of the dataset + dataverse_url (str): URL of the Dataverse instance + api_token (str): API token for authentication + config_path (Optional[str]): Path to configuration file Raises: - typer.BadParameter: If both a configuration file and a list of filepaths are specified. - typer.BadParameter: If neither a configuration file nor metadata parameters are specified. + typer.BadParameter: If both config file and filepaths are specified + typer.BadParameter: If neither config file nor required parameters are provided """ if config_path is not None and len(filepaths) > 0: raise typer.BadParameter( @@ -97,25 +108,39 @@ def main( ), config_path: Optional[str] = typer.Option( default=None, - help="Path to a JSON/YAML file containing specifications for the files to upload. Defaults to None.", + help="Path to a JSON/YAML file containing specifications for the files to upload.", ), n_jobs: int = typer.Option( default=1, - help="The number of parallel jobs to run. Defaults to -1.", + help="Number of parallel upload jobs to run.", ), ): """ - Uploads files to a Dataverse repository. - - Args: - filepaths (List[str]): A list of filepaths to upload. 
- pid (str): The persistent identifier of the Dataverse dataset. - api_token (str): The API token for the Dataverse repository. - dataverse_url (str): The URL of the Dataverse repository. - config_path (Optional[str]): Path to a JSON/YAML file containing specifications for the files to upload. Defaults to None. - n_jobs (int): The number of parallel jobs to run. Defaults to -1. + Upload files to a Dataverse repository. + + Files can be specified either directly via command line arguments or through a + configuration file. The configuration file can be either YAML or JSON format. + + If using command line arguments, you must specify: + - One or more filepaths to upload + - The dataset's persistent identifier + - A valid API token + - The Dataverse repository URL + + If using a configuration file, it should contain: + - api_token: API token for authentication + - dataverse_url: URL of the Dataverse instance + - persistent_id: Dataset persistent identifier + - files: List of file specifications + - n_jobs: (optional) Number of parallel upload jobs + + Examples: + Upload files via command line: + $ dvuploader file1.txt file2.txt --pid doi:10.5072/FK2/123456 --api-token abc123 --dataverse-url https://demo.dataverse.org + + Upload files via config file: + $ dvuploader --config-path upload_config.yaml """ - _validate_inputs( filepaths=filepaths, pid=pid, diff --git a/dvuploader/directupload.py b/dvuploader/directupload.py index f198476..c2b84b3 100644 --- a/dvuploader/directupload.py +++ b/dvuploader/directupload.py @@ -6,6 +6,8 @@ from typing import Dict, List, Optional, Tuple from urllib.parse import urljoin import aiofiles +from typing import AsyncGenerator +from rich.progress import Progress, TaskID from dvuploader.file import File from dvuploader.utils import build_url @@ -14,9 +16,9 @@ MAX_FILE_DISPLAY = int(os.environ.get("DVUPLOADER_MAX_FILE_DISPLAY", 50)) MAX_RETRIES = int(os.environ.get("DVUPLOADER_MAX_RETRIES", 10)) -assert isinstance( - MAX_FILE_DISPLAY, int -), "DVUPLOADER_MAX_FILE_DISPLAY must be an integer" +assert isinstance(MAX_FILE_DISPLAY, int), ( + "DVUPLOADER_MAX_FILE_DISPLAY must be an integer" +) assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer" @@ -41,11 +43,11 @@ async def direct_upload( Args: files (List[File]): A list of File objects to be uploaded. dataverse_url (str): The URL of the Dataverse repository. - api_token (str): The API token for the Dataverse repository. - persistent_id (str): The persistent identifier of the Dataverse dataset. - progress: The progress object to track the upload progress. - pbars: A list of progress bars to display the upload progress for each file. - n_parallel_uploads (int): The number of parallel uploads to perform. + api_token (str): The API token for authentication. + persistent_id (str): The persistent identifier of the dataset. + progress: Progress object to track upload progress. + pbars: List of progress bars for each file. + n_parallel_uploads (int): Number of concurrent uploads to perform. Returns: None @@ -118,21 +120,21 @@ async def _upload_to_store( leave_bar: bool, ): """ - Uploads a file to a Dataverse collection using direct upload. + Upload a file to Dataverse storage using direct upload. Args: - session (httpx.AsyncClient): The httpx async client session. - file (File): The file object to upload. - persistent_id (str): The persistent identifier of the Dataverse dataset to upload to. - dataverse_url (str): The URL of the Dataverse instance to upload to. 
- api_token (str): The API token to use for authentication. - pbar: The progress bar object. - progress: The progress object. - delay (float): The delay in seconds before starting the upload. - leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete. + session (httpx.AsyncClient): Async HTTP client session. + file (File): File object to upload. + persistent_id (str): Dataset persistent identifier. + dataverse_url (str): Dataverse instance URL. + api_token (str): API token for authentication. + pbar: Progress bar for this file. + progress: Progress tracking object. + delay (float): Delay before starting upload in seconds. + leave_bar (bool): Whether to keep progress bar after completion. Returns: - tuple: A tuple containing the upload status (bool) and the file object. + tuple: (success: bool, file: File) indicating upload status and file object. """ await asyncio.sleep(delay) @@ -179,21 +181,18 @@ async def _request_ticket( persistent_id: str, file_size: int, ) -> Dict: - """Requests a ticket from a Dataverse collection to perform an upload. - - This method will send a request to the Dataverse API to obtain a ticket - for performing a direct upload to an S3 bucket. The ticket contains a URL - and storageIdentifier that will be used later to perform the upload. + """ + Request an upload ticket from Dataverse. Args: - session (httpx.AsyncClient): The httpx async client session to use for the request. - dataverse_url (str): The URL of the Dataverse installation. - api_token (str): The API token used to access the dataset. - persistent_id (str): The persistent identifier of the dataset of interest. - file_size (int): The size of the file to be uploaded. + session (httpx.AsyncClient): Async HTTP client session. + dataverse_url (str): Dataverse instance URL. + api_token (str): API token for authentication. + persistent_id (str): Dataset persistent identifier. + file_size (int): Size of file to upload in bytes. Returns: - Dict: The response from the Dataverse API, containing the ticket information. + Dict: Upload ticket containing URL and storage identifier. """ url = build_url( endpoint=urljoin(dataverse_url, TICKET_ENDPOINT), @@ -202,7 +201,7 @@ async def _request_ticket( size=file_size, ) - response = await session.get(url) + response = await session.get(url, timeout=None) response.raise_for_status() return response.json()["data"] @@ -218,21 +217,22 @@ async def _upload_singlepart( leave_bar: bool, ) -> Tuple[bool, str]: """ - Uploads a single part of a file to a remote server using HTTP PUT method. + Upload a file in a single request. Args: - session (httpx.AsyncClient): The httpx async client session used for the upload. - ticket (Dict): A dictionary containing the response from the server. - filepath (str): The path to the file to be uploaded. - pbar (tqdm): A progress bar object to track the upload progress. - progress: The progress object used to update the progress bar. - leave_bar (bool): A flag indicating whether to keep the progress bar visible after the upload is complete. + session (httpx.AsyncClient): Async HTTP client session. + ticket (Dict): Upload ticket from Dataverse. + file (File): File object to upload. + pbar: Progress bar for this file. + progress: Progress tracking object. + api_token (str): API token for authentication. + leave_bar (bool): Whether to keep progress bar after completion. 
Returns: - Tuple[bool, str]: A tuple containing the status of the upload (True for success, False for failure) - and the storage identifier of the uploaded file. + Tuple[bool, str]: (success status, storage identifier) """ assert "url" in ticket, "Couldn't find 'url'" + assert file.checksum is not None, "Checksum is required for singlepart uploads" if TESTING: ticket["url"] = ticket["url"].replace("localstack", "localhost", 1) @@ -240,17 +240,26 @@ async def _upload_singlepart( headers = { "X-Dataverse-key": api_token, "x-amz-tagging": "dv-state=temp", + "Content-length": str(file._size), } + storage_identifier = ticket["storageIdentifier"] params = { "headers": headers, "url": ticket["url"], - "files": {"": file.handler}, + "content": upload_bytes( + file=file.handler, # type: ignore + progress=progress, + pbar=pbar, + hash_func=file.checksum._hash_fun, + ), } response = await session.put(**params) response.raise_for_status() + file.apply_checksum() + if response.status_code == 200: progress.update(pbar, advance=file._size) await asyncio.sleep(0.1) @@ -272,18 +281,19 @@ async def _upload_multipart( api_token: str, ): """ - Uploads a file to Dataverse using multipart upload. + Upload a file using multipart upload. Args: - session (httpx.AsyncClient): The httpx async client session. - response (Dict): The response from the Dataverse API containing the upload ticket information. - file (File): The file object to be uploaded. - dataverse_url (str): The URL of the Dataverse instance. - pbar (tqdm): A progress bar to track the upload progress. - progress: The progress callback function. + session (httpx.AsyncClient): Async HTTP client session. + response (Dict): Upload ticket response from Dataverse. + file (File): File object to upload. + dataverse_url (str): Dataverse instance URL. + pbar: Progress bar for this file. + progress: Progress tracking object. + api_token (str): API token for authentication. Returns: - Tuple[bool, str]: A tuple containing a boolean indicating the success of the upload and the storage identifier for the uploaded file. + Tuple[bool, str]: (success status, storage identifier) """ _validate_ticket_response(response) @@ -324,6 +334,9 @@ async def _upload_multipart( api_token=api_token, ) + file.apply_checksum() + print(file.checksum) + return True, storage_identifier @@ -336,19 +349,21 @@ async def _chunked_upload( progress, ): """ - Uploads a file in chunks to multiple URLs using the provided session. + Upload a file in chunks. Args: - file (File): The file object to upload. - session (httpx.AsyncClient): The httpx async client session to use for the upload. - urls: An iterable of URLs to upload the file chunks to. - chunk_size (int): The size of each chunk in bytes. - pbar (tqdm): The progress bar to update during the upload. - progress: The progress object to track the upload progress. + file (File): File object to upload. + session (httpx.AsyncClient): Async HTTP client session. + urls: Iterator of upload URLs for each chunk. + chunk_size (int): Size of each chunk in bytes. + pbar: Progress bar for this file. + progress: Progress tracking object. Returns: - List[str]: A list of ETags returned by the server for each uploaded chunk. + List[str]: ETags returned by server for each chunk. 
""" + assert file.checksum is not None, "Checksum is required for multipart uploads" + e_tags = [] if not os.path.exists(file.filepath): @@ -367,14 +382,14 @@ async def _chunked_upload( session=session, url=next(urls), file=BytesIO(chunk), + progress=progress, + pbar=pbar, + hash_func=file.checksum._hash_fun, ) ) - progress.update(pbar, advance=len(chunk)) - while chunk: chunk = await f.read(chunk_size) - progress.update(pbar, advance=len(chunk)) if not chunk: break @@ -384,6 +399,9 @@ async def _chunked_upload( session=session, url=next(urls), file=BytesIO(chunk), + progress=progress, + pbar=pbar, + hash_func=file.checksum._hash_fun, ) ) @@ -391,7 +409,15 @@ async def _chunked_upload( def _validate_ticket_response(response: Dict) -> None: - """Validate the response from the ticket request to include all necessary fields.""" + """ + Validate that upload ticket response contains required fields. + + Args: + response (Dict): Upload ticket response to validate. + + Raises: + AssertionError: If required fields are missing. + """ assert "abort" in response, "Couldn't find 'abort'" assert "complete" in response, "Couldn't find 'complete'" @@ -404,26 +430,41 @@ async def _upload_chunk( session: httpx.AsyncClient, url: str, file: BytesIO, + progress: Progress, + pbar: TaskID, + hash_func, ): """ - Uploads a chunk of data to the specified URL using the provided session. + Upload a single chunk of data. Args: - session (httpx.AsyncClient): The session to use for the upload. - url (str): The URL to upload the chunk to. - file (ChunkStream): The chunk of data to upload. - pbar: The progress bar to update during the upload. + session (httpx.AsyncClient): Async HTTP client session. + url (str): URL to upload chunk to. + file (BytesIO): Chunk data to upload. + progress (Progress): Progress tracking object. + pbar (TaskID): Progress bar task ID. + hash_func: Hash function for checksum. Returns: - str: The ETag value of the uploaded chunk. + str: ETag from server response. """ if TESTING: url = url.replace("localstack", "localhost", 1) + headers = { + "Content-length": str(len(file.getvalue())), + } + params = { + "headers": headers, "url": url, - "data": file, + "data": upload_bytes( + file=file, + progress=progress, + pbar=pbar, + hash_func=hash_func, + ), } response = await session.put(**params) @@ -439,16 +480,15 @@ async def _complete_upload( e_tags: List[Optional[str]], api_token: str, ) -> None: - """Completes the upload by sending the E tags + """ + Complete a multipart upload by sending ETags. Args: - session (httpx.AsyncClient): The aiohttp client session. - url (str): The URL to send the PUT request to. - dataverse_url (str): The base URL of the Dataverse instance. - e_tags (List[str]): The list of E tags to send in the payload. - - Raises: - aiohttp.ClientResponseError: If the response status code is not successful. + session (httpx.AsyncClient): Async HTTP client session. + url (str): URL to send completion request to. + dataverse_url (str): Dataverse instance URL. + e_tags (List[str]): List of ETags from uploaded chunks. + api_token (str): API token for authentication. """ payload = json.dumps({str(index + 1): e_tag for index, e_tag in enumerate(e_tags)}) @@ -472,16 +512,13 @@ async def _abort_upload( api_token: str, ): """ - Aborts an ongoing upload by sending a DELETE request to the specified URL. + Abort an in-progress multipart upload. Args: - session (httpx.AsyncClient): The httpx async client session. - url (str): The URL to send the DELETE request to. 
- dataverse_url (str): The base URL of the Dataverse instance. - api_token (str): The API token to use for the request. - - Raises: - aiohttp.ClientResponseError: If the DELETE request fails. + session (httpx.AsyncClient): Async HTTP client session. + url (str): URL to send abort request to. + dataverse_url (str): Dataverse instance URL. + api_token (str): API token for authentication. """ headers = {"X-Dataverse-key": api_token} @@ -500,16 +537,15 @@ async def _add_files_to_ds( pbar, ) -> None: """ - Adds a file to a Dataverse dataset. + Register uploaded files with the dataset. Args: - session (httpx.AsyncClient): The httpx async client session. - dataverse_url (str): The URL of the Dataverse instance. - pid (str): The persistent identifier of the dataset. - file (File): The file to be added. - - Returns: - bool: True if the file was added successfully, False otherwise. + session (httpx.AsyncClient): Async HTTP client session. + dataverse_url (str): Dataverse instance URL. + pid (str): Dataset persistent identifier. + files (List[File]): List of uploaded files to register. + progress: Progress tracking object. + pbar: Progress bar for registration. """ novel_url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid) @@ -535,13 +571,14 @@ async def _add_files_to_ds( def _prepare_registration(files: List[File], use_replace: bool) -> List[Dict]: """ - Prepares the files for registration at the Dataverse instance. + Prepare file metadata for registration. Args: - files (List[File]): The list of files to prepare. + files (List[File]): List of files to prepare metadata for. + use_replace (bool): Whether these are replacement files. Returns: - List[Dict]: The list of files prepared for registration. + List[Dict]: List of file metadata dictionaries. """ exclude = {"to_replace"} if use_replace else {"to_replace", "file_id"} @@ -563,18 +600,15 @@ async def _multipart_json_data_request( session: httpx.AsyncClient, ): """ - Sends a multipart/form-data POST request with JSON data to the specified URL using the provided session. + Send multipart form request with JSON data. Args: - json_data (str): The JSON data to be sent in the request body. - url (str): The URL to send the request to. - session (httpx.AsyncClient): The httpx async client session to use for the request. + json_data (List[Dict]): JSON data to send. + url (str): URL to send request to. + session (httpx.AsyncClient): Async HTTP client session. Raises: - httpx.HTTPStatusError: If the response status code is not successful. - - Returns: - None + httpx.HTTPStatusError: If request fails. """ files = { @@ -586,4 +620,43 @@ async def _multipart_json_data_request( } response = await session.post(url, files=files) - response.raise_for_status() + + if not response.is_success: + raise httpx.HTTPStatusError( + f"Failed to register files: {response.text}", + request=response.request, + response=response, + ) + + +async def upload_bytes( + file: BytesIO, + progress: Progress, + pbar: TaskID, + hash_func, +) -> AsyncGenerator[bytes, None]: + """ + Generate chunks of file data for upload. + + Args: + file (BytesIO): File to read chunks from. + progress (Progress): Progress tracking object. + pbar (TaskID): Progress bar task ID. + hash_func: Hash function for checksum. + + Yields: + bytes: Next chunk of file data. 
+ """ + while True: + data = file.read(1024 * 1024) # 1MB + + if not data: + break + + # Update the hash function with the data + hash_func.update(data) + + # Update the progress bar + progress.update(pbar, advance=len(data)) + + yield data diff --git a/dvuploader/dvuploader.py b/dvuploader/dvuploader.py index 0b0052f..aede715 100644 --- a/dvuploader/dvuploader.py +++ b/dvuploader/dvuploader.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional from pydantic import BaseModel -from rich.progress import Progress, TaskID +from rich.progress import Progress from rich.table import Table from rich.console import Console from rich.panel import Panel @@ -26,11 +26,25 @@ class DVUploader(BaseModel): Attributes: files (List[File]): A list of File objects to be uploaded. + verbose (bool): Whether to print progress and status messages. Defaults to True. Methods: - upload(persistent_id: str, dataverse_url: str, api_token: str) -> None: - Uploads the files to the specified Dataverse repository in parallel. - + upload(persistent_id: str, dataverse_url: str, api_token: str, n_parallel_uploads: int = 1, force_native: bool = False, replace_existing: bool = True) -> None: + Uploads the files to the specified Dataverse repository. + _validate_files() -> None: + Validates and hashes the files to be uploaded. + _validate_file(file: File) -> None: + Validates and hashes a single file. + _check_duplicates(dataverse_url: str, persistent_id: str, api_token: str, replace_existing: bool) -> None: + Checks for duplicate files in the dataset. + _get_file_id(file: File, ds_files: List[Dict]) -> Optional[str]: + Gets the file ID for a given file in a dataset. + _check_hashes(file: File, dsFile: Dict) -> bool: + Checks if a file has the same checksum as a file in the dataset. + _has_direct_upload(dataverse_url: str, api_token: str, persistent_id: str) -> bool: + Checks if direct upload is supported by the Dataverse instance. + setup_progress_bars(files: List[File]) -> Tuple[Progress, List[TaskID]]: + Sets up progress bars for tracking file uploads. """ files: List[File] @@ -43,15 +57,19 @@ def upload( api_token: str, n_parallel_uploads: int = 1, force_native: bool = False, + replace_existing: bool = True, ) -> None: """ - Uploads the files to the specified Dataverse repository in parallel. + Uploads the files to the specified Dataverse repository. Args: persistent_id (str): The persistent identifier of the Dataverse dataset. dataverse_url (str): The URL of the Dataverse repository. api_token (str): The API token for the Dataverse repository. - n_parallel_uploads (int): The number of parallel uploads to execute. In the case of direct upload, this restricts the amount of parallel chunks per upload. Please use n_jobs to control parallel files. + n_parallel_uploads (int): The number of parallel uploads to execute. For direct upload, + this restricts parallel chunks per upload. Use n_jobs to control parallel files. + force_native (bool): Forces the use of the native upload method instead of direct upload. + replace_existing (bool): Whether to replace files that already exist in the dataset. 
Returns: None @@ -77,13 +95,14 @@ def upload( if self.verbose: rich.print(panel) - asyncio.run(self._validate_and_hash_files(verbose=self.verbose)) + asyncio.run(self._validate_files()) # Check for duplicates self._check_duplicates( dataverse_url=dataverse_url, persistent_id=persistent_id, api_token=api_token, + replace_existing=replace_existing, ) # Sort files by size @@ -144,7 +163,7 @@ def upload( if self.verbose: rich.print("\n[bold italic white]āœ… Upload complete\n") - async def _validate_and_hash_files(self, verbose: bool): + async def _validate_files(self): """ Validates and hashes the files to be uploaded. @@ -152,62 +171,41 @@ async def _validate_and_hash_files(self, verbose: bool): None """ - if not verbose: - tasks = [ - self._validate_and_hash_file(file=file, verbose=self.verbose) - for file in self.files - ] - - await asyncio.gather(*tasks) - return - - print("\n") - - progress = Progress() - task = progress.add_task( - "[bold italic white]\nšŸ“¦ Preparing upload[/bold italic white]", - total=len(self.files), - ) - - with progress: - tasks = [ - self._validate_and_hash_file( - file=file, progress=progress, task_id=task, verbose=self.verbose - ) - for file in self.files - ] - - await asyncio.gather(*tasks) + tasks = [self._validate_file(file=file) for file in self.files] - print("\n") + await asyncio.gather(*tasks) @staticmethod - async def _validate_and_hash_file( - file: File, - verbose: bool, - progress: Optional[Progress] = None, - task_id: Optional[TaskID] = None, - ): - file.extract_file_name_hash_file() + async def _validate_file(file: File): + """ + Validates and hashes a single file. - if verbose: - progress.update(task_id, advance=1) # type: ignore + Args: + file (File): The file to validate and hash. + + Returns: + None + """ + file.extract_file_name() def _check_duplicates( self, dataverse_url: str, persistent_id: str, api_token: str, + replace_existing: bool, ): """ - Checks for duplicate files in the dataset by comparing the checksums. + Checks for duplicate files in the dataset by comparing paths and filenames. - Parameters: - dataverse_url (str): The URL of the dataverse. + Args: + dataverse_url (str): The URL of the Dataverse repository. persistent_id (str): The persistent ID of the dataset. - api_token (str): The API token for accessing the dataverse. + api_token (str): The API token for accessing the Dataverse repository. + replace_existing (bool): Whether to replace files that already exist. - Prints a message for each file that already exists in the dataset with the same checksum. 
+ Returns: + None """ ds_files = retrieve_dataset_files( @@ -224,43 +222,48 @@ def _check_duplicates( table.add_column("Status") table.add_column("Action") - to_remove = [] over_threshold = len(self.files) > 50 + to_skip = [] n_new_files = 0 - n_skip_files = 0 + n_replace_files = 0 for file in self.files: - has_same_hash = any( - map(lambda dsFile: self._check_hashes(file, dsFile), ds_files) - ) + # If present in dataset, replace file + file.file_id = self._get_file_id(file, ds_files) + file.to_replace = True if file.file_id else False - if has_same_hash: - n_skip_files += 1 - table.add_row( - file.file_name, "[bright_black]Same hash", "[bright_black]Skip" - ) - to_remove.append(file) + if file.to_replace: + n_replace_files += 1 + to_skip.append(file.file_id) + + if replace_existing: + table.add_row( + file.file_name, "[bright_cyan]Exists", "[bright_black]Replace" + ) + else: + table.add_row( + file.file_name, "[bright_cyan]Exists", "[bright_black]Skipping" + ) else: n_new_files += 1 table.add_row( - file.file_name, "[spring_green3]New", "[spring_green3]Upload" + file.file_name, "[spring_green3]New", "[bright_black]Upload" ) - # If present in dataset, replace file - file.file_id = self._get_file_id(file, ds_files) - file.to_replace = True if file.file_id else False - - for file in to_remove: - self.files.remove(file) - console = Console() + if not replace_existing: + console.print( + f"\nSkipping {len(to_skip)} existing files. Use `replace_existing=True` to replace them.\n" + ) + self.files = [file for file in self.files if not file.to_replace] + if over_threshold: table = Table(title="[bold white]šŸ”Ž Checking dataset files") table.add_column("New", style="spring_green3", no_wrap=True) - table.add_column("Skipped", style="bright_black", no_wrap=True) - table.add_row(str(n_new_files), str(n_skip_files)) + table.add_column("Replace", style="bright_black", no_wrap=True) + table.add_row(str(n_new_files), str(n_replace_files)) if self.verbose: console.print(table) @@ -271,18 +274,14 @@ def _get_file_id( ds_files: List[Dict], ) -> Optional[str]: """ - Get the file ID for a given file in a dataset. + Gets the file ID for a given file in a dataset. Args: file (File): The file object to find the ID for. ds_files (List[Dict]): List of dictionary objects representing dataset files. - persistent_id (str): The persistent ID of the dataset. Returns: - str: The ID of the file. - - Raises: - ValueError: If the file cannot be found in the dataset. + Optional[str]: The ID of the file if found, None otherwise. """ # Find the file that matches label and directory_label @@ -298,12 +297,12 @@ def _check_hashes(file: File, dsFile: Dict): """ Checks if a file has the same checksum as a file in the dataset. - Parameters: + Args: file (File): The file to check. - dsFile (Dict): The file in the dataset to compare to. + dsFile (Dict): The file in the dataset to compare against. Returns: - bool: True if the files have the same checksum, False otherwise. + bool: True if the files have matching checksums and paths, False otherwise. """ if not file.checksum: @@ -326,7 +325,17 @@ def _has_direct_upload( api_token: str, persistent_id: str, ) -> bool: - """Checks if the response from the ticket request contains a direct upload URL""" + """ + Checks if direct upload is supported by the Dataverse instance. + + Args: + dataverse_url (str): The URL of the Dataverse repository. + api_token (str): The API token for the Dataverse repository. + persistent_id (str): The persistent ID of the dataset. 
+ + Returns: + bool: True if direct upload is supported, False otherwise. + """ query = build_url( endpoint=urljoin(dataverse_url, TICKET_ENDPOINT), @@ -345,10 +354,13 @@ def _has_direct_upload( def setup_progress_bars(self, files: List[File]): """ - Sets up progress bars for each file in the uploader. + Sets up progress bars for tracking file uploads. + + Args: + files (List[File]): The list of files to create progress bars for. Returns: - A list of progress bars, one for each file in the uploader. + Tuple[Progress, List[TaskID]]: The Progress object and list of task IDs for the progress bars. """ progress = Progress() diff --git a/dvuploader/file.py b/dvuploader/file.py index 9a42740..8a9bb44 100644 --- a/dvuploader/file.py +++ b/dvuploader/file.py @@ -10,10 +10,11 @@ class File(BaseModel): """ - Represents a file with its properties and methods. + Represents a file with its properties and methods for uploading to Dataverse. Attributes: filepath (str): The path to the file. + handler (Union[BytesIO, StringIO, IO, None]): File handler for reading the file contents. description (str): The description of the file. directory_label (str): The label of the directory where the file is stored. mimeType (str): The MIME type of the file. @@ -24,12 +25,15 @@ class File(BaseModel): file_name (Optional[str]): The name of the file. checksum (Optional[Checksum]): The checksum of the file. to_replace (bool): Indicates if the file should be replaced. - file_id (Optional[str]): The ID of the file. + file_id (Optional[Union[str, int]]): The ID of the file to replace. + + Private Attributes: + _size (int): Size of the file in bytes. Methods: + extract_file_name(): Extracts filename from filepath and initializes file handler. _validate_filepath(path): Validates if the file path exists and is a file. - _extract_file_name_hash_file(): Extracts the file_name from the filepath and calculates the file's checksum. - + apply_checksum(): Calculates and applies the checksum for the file. """ model_config = ConfigDict( @@ -53,9 +57,10 @@ class File(BaseModel): _size: int = PrivateAttr(default=0) - def extract_file_name_hash_file(self): + def extract_file_name(self): """ - Extracts the file_name and calculates the hash of the file. + Extracts the file name from the file path and initializes the file handler. + Also calculates the file size and prepares for checksum calculation. Returns: self: The current instance of the class. @@ -76,8 +81,7 @@ def extract_file_name_hash_file(self): if self.file_name is None: self.file_name = os.path.basename(self.filepath) - self.checksum = Checksum.from_file( - handler=self.handler, + self.checksum = Checksum.from_algo( hash_fun=hash_fun, hash_algo=hash_algo, ) @@ -94,9 +98,26 @@ def _validate_filepath(path): Raises: FileNotFoundError: If the filepath does not exist. - TypeError: If the filepath is not a file. + IsADirectoryError: If the filepath points to a directory instead of a file. """ if not os.path.exists(path): raise FileNotFoundError(f"Filepath {path} does not exist.") elif not os.path.isfile(path): raise IsADirectoryError(f"Filepath {path} is not a file.") + + def apply_checksum(self): + """ + Calculates and applies the checksum for the file. + Must be called after extract_file_name() has initialized the checksum. + + Raises: + AssertionError: If checksum is not initialized or hash function is not set. + """ + assert self.checksum is not None, "Checksum is not calculated." + assert self.checksum._hash_fun is not None, "Checksum hash function is not set." 
+ + self.checksum.apply_checksum() + + def __del__(self): + if self.handler is not None: + self.handler.close() diff --git a/dvuploader/nativeupload.py b/dvuploader/nativeupload.py index 622b986..dc0292b 100644 --- a/dvuploader/nativeupload.py +++ b/dvuploader/nativeupload.py @@ -39,9 +39,11 @@ async def native_upload( api_token (str): The API token for the Dataverse repository. persistent_id (str): The persistent identifier of the Dataverse dataset. n_parallel_uploads (int): The number of parallel uploads to execute. + pbars: List of progress bar IDs to track upload progress. + progress: Progress object to manage progress bars. Returns: - List[requests.Response]: The list of responses for each file upload. + None """ _reset_progress(pbars, progress) @@ -89,7 +91,16 @@ def _validate_upload_responses( responses: List[Tuple], files: List[File], ) -> None: - """Validates the responses of the native upload requests.""" + """ + Validates the responses of the native upload requests. + + Args: + responses (List[Tuple]): List of tuples containing status code and response data. + files (List[File]): List of files that were uploaded. + + Returns: + None + """ for (status, response), file in zip(responses, files): if status == 200: @@ -109,9 +120,10 @@ def _zip_packages( Args: packages (List[Tuple[int, List[File]]]): The packages to be zipped. tmp_dir (str): The temporary directory to store the zip files in. + progress (Progress): Progress object to manage progress bars. Returns: - List[File, TaskID]: The list of zip files. + List[Tuple[TaskID, File]]: List of tuples containing progress bar ID and File object. """ files = [] @@ -120,15 +132,14 @@ def _zip_packages( if len(package) == 1: file = package[0] else: - file = File( - filepath=zip_files( - files=package, - tmp_dir=tmp_dir, - index=index, - ), + path = zip_files( + files=package, + tmp_dir=tmp_dir, + index=index, ) - file.extract_file_name_hash_file() + file = File(filepath=path) + file.extract_file_name() file.mimeType = "application/zip" pbar = progress.add_task( @@ -149,7 +160,8 @@ def _reset_progress( Resets the progress bars to zero. Args: - pbars: The progress bars to reset. + pbars (List[TaskID]): List of progress bar IDs to reset. + progress (Progress): Progress object managing the progress bars. Returns: None @@ -174,14 +186,16 @@ async def _single_native_upload( Uploads a file to a Dataverse repository using the native upload method. Args: - session (httpx.AsyncClient): The aiohttp client session. + session (httpx.AsyncClient): The httpx client session. file (File): The file to be uploaded. persistent_id (str): The persistent identifier of the dataset. - pbar: The progress bar object. - progress: The progress object. + pbar: Progress bar ID for tracking upload progress. + progress: Progress object managing the progress bars. Returns: - tuple: A tuple containing the status code and the JSON response from the upload request. + tuple: A tuple containing: + - int: Status code (200 for success, False for failure) + - dict: JSON response from the upload request """ if not file.to_replace: @@ -227,7 +241,15 @@ async def _single_native_upload( def _get_json_data(file: File) -> Dict: - """Returns the JSON data for the native upload request.""" + """ + Returns the JSON data for the native upload request. + + Args: + file (File): The file to create JSON data for. + + Returns: + Dict: Dictionary containing file metadata for the upload request. 
+ """ return { "description": file.description, "directoryLabel": file.directory_label, @@ -244,15 +266,18 @@ async def _update_metadata( api_token: str, persistent_id: str, ): - """Updates the metadata of the given files in a Dataverse repository. + """ + Updates the metadata of the given files in a Dataverse repository. Args: - session (httpx.AsyncClient): The httpx async client. files (List[File]): The files to update the metadata for. dataverse_url (str): The URL of the Dataverse repository. api_token (str): The API token of the Dataverse repository. persistent_id (str): The persistent identifier of the dataset. + + Raises: + ValueError: If a file is not found in the Dataverse repository. """ file_mapping = _retrieve_file_ids( @@ -296,7 +321,17 @@ async def _update_single_metadata( url: str, file: File, ) -> None: - """Updates the metadata of a single file in a Dataverse repository.""" + """ + Updates the metadata of a single file in a Dataverse repository. + + Args: + session (httpx.AsyncClient): The httpx async client. + url (str): The URL endpoint for updating metadata. + file (File): The file to update metadata for. + + Raises: + ValueError: If metadata update fails. + """ json_data = _get_json_data(file) @@ -329,16 +364,16 @@ def _retrieve_file_ids( dataverse_url: str, api_token: str, ) -> Dict[str, str]: - """Retrieves the file IDs of the given files. + """ + Retrieves the file IDs of files in a dataset. Args: - files (List[File]): The files to retrieve the IDs for. persistent_id (str): The persistent identifier of the dataset. dataverse_url (str): The URL of the Dataverse repository. api_token (str): The API token of the Dataverse repository. Returns: - Dict[str, str]: The list of file IDs. + Dict[str, str]: Dictionary mapping file paths to their IDs. """ # Fetch file metadata @@ -352,7 +387,15 @@ def _retrieve_file_ids( def _create_file_id_path_mapping(files): - """Creates dictionary that maps from directoryLabel + filename to ID""" + """ + Creates dictionary that maps from directoryLabel + filename to ID. + + Args: + files: List of file metadata from Dataverse. + + Returns: + Dict[str, str]: Dictionary mapping file paths to their IDs. + """ mapping = {} for file in files: diff --git a/dvuploader/packaging.py b/dvuploader/packaging.py index eb3710d..c99d4d1 100644 --- a/dvuploader/packaging.py +++ b/dvuploader/packaging.py @@ -12,16 +12,17 @@ ) -def distribute_files(dv_files: List[File]): +def distribute_files(dv_files: List[File]) -> List[Tuple[int, List[File]]]: """ Distributes a list of files into packages based on their sizes. Args: dv_files (List[File]): The list of files to be distributed. - maximum_size (int, optional): The maximum size of each package in bytes. Defaults to 2 * 1024**3. Returns: - List[List[File]]: The distributed packages of files. + List[Tuple[int, List[File]]]: A list of tuples containing package index and list of files. + Files are grouped into packages that don't exceed MAXIMUM_PACKAGE_SIZE. + Files larger than MAXIMUM_PACKAGE_SIZE are placed in their own package. """ packages = [] current_package = [] @@ -56,16 +57,17 @@ def distribute_files(dv_files: List[File]): def _append_and_reset( package: Tuple[int, List[File]], packages: List[Tuple[int, List[File]]], -): +) -> Tuple[List[File], int, int]: """ - Appends the given package to the packages list and resets the package list. + Appends the given package to the packages list and resets the package state. Args: - package (List[File]): The package to be appended. 
- packages (List[List[File]]): The list of packages. + package (Tuple[int, List[File]]): Tuple containing package index and list of files. + packages (List[Tuple[int, List[File]]]): The list of all packages. Returns: - Tuple[List[File], int]: The updated package list and the count of packages. + Tuple[List[File], int, int]: Empty list for new package, reset size counter (0), + and incremented package index. """ packages.append(package) return [], 0, package[0] + 1 @@ -75,18 +77,20 @@ def zip_files( files: List[File], tmp_dir: str, index: int, -): +) -> str: """ - Zips the given files into a zip file. + Creates a zip file containing the given files. Args: files (List[File]): The files to be zipped. - tmp_dir (str): The temporary directory to store the zip file in. + tmp_dir (str): The temporary directory to store the zip file. + index (int): Index used in the zip filename. Returns: - str: The path to the zip file. + str: The full path to the created zip file. """ - path = os.path.join(tmp_dir, f"package_{index}.zip") + name = f"package_{index}.zip" + path = os.path.join(tmp_dir, name) with zipfile.ZipFile(path, "w") as zip_file: for file in files: @@ -98,15 +102,16 @@ def zip_files( return path -def _create_arcname(file: File): +def _create_arcname(file: File) -> str: """ - Creates the arcname for the given file. + Creates the archive name (path within zip) for the given file. Args: - file (File): The file to create the arcname for. + file (File): The file to create the archive name for. Returns: - str: The arcname for the given file. + str: The archive name - either just the filename, or directory_label/filename + if directory_label is set. """ if file.directory_label is not None: return os.path.join(file.directory_label, file.file_name) # type: ignore diff --git a/dvuploader/utils.py b/dvuploader/utils.py index 47916d0..7794bc9 100644 --- a/dvuploader/utils.py +++ b/dvuploader/utils.py @@ -44,9 +44,10 @@ def retrieve_dataset_files( """ Retrieve the files of a specific dataset from a Dataverse repository. - Parameters: + Args: dataverse_url (str): The base URL of the Dataverse repository. persistent_id (str): The persistent identifier (PID) of the dataset. + api_token (str): API token for authentication. Returns: list: A list of files in the dataset. @@ -76,9 +77,9 @@ def add_directory( Recursively adds all files in the specified directory to a list of File objects. Args: - directory (str): The directory path. + directory (str): The directory path to scan for files. ignore (List[str], optional): A list of regular expressions to ignore certain files or directories. Defaults to [r"^\."]. - rootDirectoryLabel (str, optional): The label to be added to the directory path of each file. Defaults to "". + rootDirectoryLabel (str, optional): The label to be prepended to the directory path of each file. Defaults to "". Returns: List[File]: A list of File objects representing the files in the directory. @@ -114,14 +115,14 @@ def add_directory( def _truncate_path(path: pathlib.Path, to_remove: pathlib.Path): """ - Truncate a path by removing a substring from the beginning. + Truncate a path by removing a prefix path. Args: - path (str): The path to truncate. - to_remove (str): The substring to remove from the beginning of the path. + path (pathlib.Path): The full path to truncate. + to_remove (pathlib.Path): The prefix path to remove. Returns: - str: The truncated path. + str: The truncated path as a string, or empty string if nothing remains after truncation. 
""" parts = path.parts[len(to_remove.parts) :] @@ -134,14 +135,14 @@ def _truncate_path(path: pathlib.Path, to_remove: pathlib.Path): def part_is_ignored(part, ignore): """ - Check if a part should be ignored based on a list of patterns. + Check if a path part should be ignored based on a list of regex patterns. Args: - part (str): The part to check. - ignore (list): A list of patterns to match against. + part (str): The path part to check. + ignore (List[str]): A list of regex patterns to match against. Returns: - bool: True if the part should be ignored, False otherwise. + bool: True if the part matches any ignore pattern, False otherwise. """ for pattern in ignore: if re.match(pattern, part): @@ -154,14 +155,14 @@ def setup_pbar( progress: Progress, ) -> int: """ - Set up a progress bar for a file. + Set up a progress bar for tracking file upload progress. Args: - fpath (str): The path to the file. - progress (Progress): The progress bar object. + file (File): The File object containing file information. + progress (Progress): The rich Progress instance for displaying progress. Returns: - int: The task ID of the progress bar. + int: The task ID for the created progress bar. """ file_size = file._size diff --git a/tests/integration/test_native_upload.py b/tests/integration/test_native_upload.py index 57c2ed8..821b4fa 100644 --- a/tests/integration/test_native_upload.py +++ b/tests/integration/test_native_upload.py @@ -157,17 +157,16 @@ def test_native_upload_by_handler( assert len(files) == 2 for ex_dir, ex_f in expected: - file = next(file for file in files if file["label"] == ex_f) - assert ( - file["label"] == ex_f - ), f"File label does not match for file {json.dumps(file)}" + assert file["label"] == ex_f, ( + f"File label does not match for file {json.dumps(file)}" + ) - assert ( - file.get("directoryLabel", "") == ex_dir - ), f"Directory label does not match for file {json.dumps(file)}" + assert file.get("directoryLabel", "") == ex_dir, ( + f"Directory label does not match for file {json.dumps(file)}" + ) - assert ( - file["description"] == "This is a test" - ), f"Description does not match for file {json.dumps(file)}" + assert file["description"] == "This is a test", ( + f"Description does not match for file {json.dumps(file)}" + ) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index ba2438c..1205ecc 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -14,7 +14,7 @@ def test_full_input(self): # Act cli_input = _parse_yaml_config(fpath) - [file.extract_file_name_hash_file() for file in cli_input.files] + [file.extract_file_name() for file in cli_input.files] # Assert expected_files = [ diff --git a/tests/unit/test_file.py b/tests/unit/test_file.py index 314c0b1..5232cff 100644 --- a/tests/unit/test_file.py +++ b/tests/unit/test_file.py @@ -13,7 +13,7 @@ def test_read_file(self): directory_label="", ) - file.extract_file_name_hash_file() + file.extract_file_name() # Assert assert file.file_name == "somefile.txt" @@ -29,7 +29,7 @@ def test_read_non_existent_file(self): directory_label="", ) - file.extract_file_name_hash_file() + file.extract_file_name() def test_read_non_file(self): # Arrange @@ -42,4 +42,4 @@ def test_read_non_file(self): directory_label="", ) - file.extract_file_name_hash_file() + file.extract_file_name() diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index f6d906d..bd386d9 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -19,7 +19,7 @@ def test_all_files_added_except_hidden(self): # Act files = 
add_directory(directory) - [file.extract_file_name_hash_file() for file in files] + [file.extract_file_name() for file in files] # Assert expected_files = [ @@ -33,14 +33,14 @@ def test_all_files_added_except_hidden(self): assert len(files) == len(expected_files), "Wrong number of files" for directory_label, file_name in expected_files: - assert any( - file.file_name == file_name for file in files - ), f"File {file_name} not found in files" + assert any(file.file_name == file_name for file in files), ( + f"File {file_name} not found in files" + ) file = next(filter(lambda file: file.file_name == file_name, files)) - assert ( - file.directory_label == directory_label - ), f"File {file_name} has wrong directory label" + assert file.directory_label == directory_label, ( + f"File {file_name} has wrong directory label" + ) def test_all_files_added_except_hidden_and_dunder(self): # Arrange @@ -48,7 +48,7 @@ def test_all_files_added_except_hidden_and_dunder(self): # Act files = add_directory(directory, ignore=[r"^\.", "__.*__"]) - [file.extract_file_name_hash_file() for file in files] + [file.extract_file_name() for file in files] # Assert expected_files = [ @@ -61,14 +61,14 @@ def test_all_files_added_except_hidden_and_dunder(self): assert len(files) == len(expected_files), "Wrong number of files" for directory_label, file_name in expected_files: - assert any( - file.file_name == file_name for file in files - ), f"File {file_name} not found in files" + assert any(file.file_name == file_name for file in files), ( + f"File {file_name} not found in files" + ) file = next(filter(lambda file: file.file_name == file_name, files)) - assert ( - file.directory_label == directory_label - ), f"File {file_name} has wrong directory label" + assert file.directory_label == directory_label, ( + f"File {file_name} has wrong directory label" + ) class TestBuildUrl:
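
The hunks above replace the old hash-the-whole-file-up-front flow (Checksum.from_file / extract_file_name_hash_file) with a streaming pattern: upload_bytes yields fixed-size chunks into the PUT body while feeding the same bytes to the hash object, and Checksum.apply_checksum finalizes the digest once the stream has been drained. The sketch below reproduces that pattern outside dvuploader so the control flow is easier to follow; it assumes only the standard library, and stream_with_checksum / fake_put are illustrative stand-ins, not dvuploader or httpx APIs.

import asyncio
import hashlib
from io import BytesIO
from typing import AsyncGenerator


async def stream_with_checksum(
    handler: BytesIO,
    hash_obj,
    blocksize: int = 2**20,
) -> AsyncGenerator[bytes, None]:
    """Yield blocksize-sized chunks while updating hash_obj with the same bytes."""
    while True:
        chunk = handler.read(blocksize)
        if not chunk:
            break
        hash_obj.update(chunk)  # mirrors hash_func.update(data) in upload_bytes
        yield chunk


async def fake_put(body: AsyncGenerator[bytes, None]) -> int:
    """Stand-in for session.put(content=...): drains the generator and counts bytes."""
    total = 0
    async for chunk in body:
        total += len(chunk)
    return total


async def main() -> None:
    payload = b"x" * (3 * 2**20 + 123)  # ~3 MB, so several chunks are produced
    hash_obj = hashlib.md5()            # analogous to Checksum.from_algo(hashlib.md5, "MD5")
    sent = await fake_put(stream_with_checksum(BytesIO(payload), hash_obj))
    digest = hash_obj.hexdigest()       # analogous to Checksum.apply_checksum()
    assert sent == len(payload)
    assert digest == hashlib.md5(payload).hexdigest()


if __name__ == "__main__":
    asyncio.run(main())

Because the digest is only final after the body generator is exhausted, apply_checksum (or hexdigest here) must be called after the PUT completes, which is exactly where the diff places it in _upload_singlepart and _upload_multipart.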