diff --git a/docs/howto/asyncDownloadMultiple.md b/docs/howto/asyncDownloadMultiple.md new file mode 100644 index 00000000..4dee1153 --- /dev/null +++ b/docs/howto/asyncDownloadMultiple.md @@ -0,0 +1,186 @@ +## Asynchronous Multiple File Downloads + +The Gen3 SDK provides an optimized asynchronous download method `async_download_multiple` for efficiently downloading large numbers of files with high throughput and memory efficiency. + +## Overview + +The `async_download_multiple` method implements a hybrid architecture combining: + +- **Multiprocessing**: Multiple Python subprocesses for CPU utilization +- **Asyncio**: High I/O concurrency within each process +- **Queue-based memory management**: Efficient handling of large file sets +- **Just-in-time presigned URL generation**: Optimized authentication flow + +## Architecture + +### Concurrency Model + +The implementation uses a three-tier architecture: + +1. **Producer Thread**: Feeds GUIDs to worker processes via bounded queues +2. **Worker Processes**: Multiple Python subprocesses with asyncio event loops +3. **Queue System**: Memory-efficient streaming of work items + +```python +# Architecture overview +Producer Thread → Input Queue → Worker Processes → Output Queue → Results + (1) (configurable) (configurable) (configurable) (Final) +``` + +### Key Features + +- **Memory Efficiency**: Bounded queues prevent memory explosion with large file sets +- **True Parallelism**: Multiprocessing bypasses Python GIL limitations +- **High Concurrency**: Configurable concurrent downloads per process +- **Resume Support**: Skip completed files with `--skip-completed` flag +- **Progress Tracking**: Real-time progress bars and detailed reporting + +## Usage + +### Command Line Interface + +Download multiple files using a manifest: + +```bash +gen3 --endpoint my-commons.org --auth credentials.json download-multiple \ + --manifest files.json \ + --download-path ./downloads \ + --max-concurrent-requests 10 \ + --filename-format original \ + --skip-completed \ + --no-prompt +``` + +### Python API + +The `async_download_multiple` method is available in the `Gen3File` class for programmatic use. Refer to the Python SDK documentation for the complete API reference. + +## Parameters + +For detailed parameter information and current default values, run: + +```bash +gen3 download-multiple --help +``` + +The command supports various options for customizing download behavior, including concurrency settings, file naming strategies, and progress controls. + +## Performance Characteristics + +### Throughput Optimization + +The method is optimized for high-throughput scenarios: + +- **Concurrent Downloads**: Configurable number of simultaneous downloads +- **Memory Usage**: Bounded by queue sizes (typically < 100MB) +- **CPU Utilization**: Leverages multiple CPU cores +- **Network Efficiency**: Just-in-time presigned URL generation + +### Scalability + +Performance scales with: + +- **File Count**: Linear time complexity with constant memory usage +- **File Size**: Independent of individual file sizes +- **Network Bandwidth**: Limited by available bandwidth and concurrent connections +- **System Resources**: Scales with available CPU cores and memory + +## Error Handling + +### Robust Error Recovery + +The implementation includes comprehensive error handling: + +- **Network Failures**: Automatic retry with exponential backoff +- **Authentication Errors**: Token refresh and retry +- **File System Errors**: Graceful handling of permission and space issues +- **Process Failures**: Automatic worker process restart + +### Result Reporting + +The method returns a structured result object containing lists of succeeded, failed, and skipped downloads with detailed information about each operation. + +## Best Practices + +### Configuration Recommendations + +For optimal performance, adjust the concurrency and process settings based on your specific use case: + +- **Small files**: Use higher concurrent request limits +- **Large files**: Use lower concurrent request limits to avoid overwhelming the system +- **High-bandwidth networks**: Increase the number of worker processes +- **Limited memory**: Reduce queue sizes to manage memory usage + + +## Comparison with Synchronous Downloads + +### Performance Advantages + +| Metric | Synchronous | Asynchronous | +| ------------------ | ---------------------------- | ---------------------------- | +| Memory Usage | O(n) - grows with file count | O(1) - bounded by queue size | +| CPU Utilization | Single core | Multiple cores | +| Network Efficiency | Sequential | Parallel | +| Scalability | Limited by GIL | Scales with CPU cores | + +## Troubleshooting + +### Common Issues + +**Slow Downloads:** + +- Check network bandwidth and server limits +- Reduce concurrent request limits if server is overwhelmed + +**Memory Issues:** + +- Reduce queue sizes and batch sizes +- Lower the number of worker processes if system memory is limited +- Monitor system memory usage during downloads + +**Authentication Errors:** + +- Verify credentials file is valid and not expired +- Check endpoint URL is correct +- Ensure proper permissions for target files + +**Process Failures:** + +- Check system resources (CPU, memory, file descriptors) +- Verify network connectivity to Gen3 commons +- Review logs for specific error messages + +### Debugging + +Enable verbose logging for detailed debugging: + +```bash +gen3 -vv --endpoint my-commons.org --auth credentials.json download-multiple \ + --manifest files.json \ + --download-path ./downloads +``` + +## Examples + +### Basic Usage + +```bash +# Download files with default settings +gen3 --endpoint data.commons.io --auth creds.json download-multiple \ + --manifest my_files.json \ + --download-path ./data +``` + +### High-Performance Configuration + +```bash +# Optimized for high-throughput downloads +gen3 --endpoint data.commons.io --auth creds.json download-multiple \ + --manifest large_dataset.json \ + --download-path ./large_downloads \ + --max-concurrent-requests 8 \ + --no-progress \ + --skip-completed +``` + +**Note**: The specific values shown in examples (like `--max-concurrent-requests 8`) are for demonstration only. For current parameter options and default values, always refer to the command line help: `gen3 download-multiple --help` diff --git a/gen3/cli/__main__.py b/gen3/cli/__main__.py index 5a51be11..a701efb5 100644 --- a/gen3/cli/__main__.py +++ b/gen3/cli/__main__.py @@ -15,6 +15,7 @@ import gen3.cli.drs_pull as drs_pull import gen3.cli.users as users import gen3.cli.wrap as wrap +import gen3.cli.download as download import gen3 from gen3 import logging as sdklogging from gen3.cli import nih @@ -142,6 +143,8 @@ def main( main.add_command(objects.objects) main.add_command(drs_pull.drs_pull) main.add_command(file.file) +main.add_command(download.download_single, name="download-single") +main.add_command(download.download_multiple, name="download-multiple") main.add_command(nih.nih) main.add_command(users.users) main.add_command(wrap.run) diff --git a/gen3/cli/download.py b/gen3/cli/download.py new file mode 100644 index 00000000..854382ae --- /dev/null +++ b/gen3/cli/download.py @@ -0,0 +1,288 @@ +""" +Gen3 download commands for CLI. +""" + +import asyncio +import json +from datetime import datetime +from typing import List, Dict, Any + +import click + +from cdislogging import get_logger +from gen3.file import Gen3File + +logging = get_logger("__name__") + + +def get_or_create_event_loop_for_thread(): + """Get or create event loop for current thread.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + +def load_manifest(manifest_path: str) -> List[Dict[str, Any]]: + """Load manifest from JSON file. + + Args: + manifest_path (str): Path to the manifest JSON file. + + Returns: + List[Dict[str, Any]]: List of dictionaries containing file information. + + Raises: + FileNotFoundError: If the manifest file does not exist. + json.JSONDecodeError: If the manifest file contains invalid JSON. + """ + try: + with open(manifest_path, "r") as f: + return json.load(f) + except (FileNotFoundError, json.JSONDecodeError) as e: + raise click.ClickException(f"Error loading manifest: {e}") + + +def validate_manifest(manifest_data: List[Dict[str, Any]]) -> bool: + """Validate manifest structure. + + Args: + manifest_data (List[Dict[str, Any]]): List of dictionaries to validate. + + Returns: + bool: True if manifest is valid, False otherwise. + """ + if not isinstance(manifest_data, list): + return False + + for item in manifest_data: + if not isinstance(item, dict): + return False + if not (item.get("guid") or item.get("object_id")): + return False + + return True + + +@click.command() +@click.argument("guid") +@click.option( + "--download-path", + default=f"download_{datetime.now().strftime('%d_%b_%Y')}", + help="Directory to download file to (default: timestamped folder)", +) +@click.option( + "--filename-format", + default="original", + type=click.Choice(["original", "guid", "combined"]), + help="Filename format: 'original' uses the original filename from metadata, 'guid' uses only the file GUID, 'combined' uses original filename with GUID appended (default: original)", +) +@click.option( + "--protocol", + default=None, + help="Protocol for presigned URL (e.g., s3) (default: auto-detect)", +) +@click.option( + "--skip-completed", + type=bool, + default=True, + help="Skip files that already exist (default: true)", +) +@click.option( + "--rename", is_flag=True, help="Rename file if it already exists (default: false)" +) +@click.option( + "--no-prompt", is_flag=True, help="Do not prompt for confirmations (default: false)" +) +@click.option( + "--no-progress", is_flag=True, help="Disable progress bar (default: false)" +) +@click.pass_context +def download_single( + ctx, + guid, + download_path, + filename_format, + protocol, + skip_completed=True, + rename=False, + no_prompt=False, + no_progress=False, +): + """Download a single file by GUID.""" + auth = ctx.obj["auth_factory"].get() + + try: + file_client = Gen3File(auth_provider=auth) + + result = file_client.download_single( + guid=guid, + download_path=download_path, + filename_format=filename_format, + protocol=protocol, + skip_completed=skip_completed, + rename=rename, + ) + + if result["status"] == "downloaded": + click.echo(f"✓ Downloaded: {result['filepath']}") + elif result["status"] == "skipped": + click.echo(f"- Skipped: {result.get('reason', 'Already exists')}") + else: + click.echo(f"✗ Failed: {result.get('error', 'Unknown error')}") + raise click.ClickException("Download failed") + + except Exception as e: + logging.error(f"Download failed: {e}") + raise click.ClickException(f"Download failed: {e}") + + +@click.command() +@click.option("--manifest", required=True, help="Path to manifest JSON file") +@click.option( + "--download-path", + default=f"download_{datetime.now().strftime('%d_%b_%Y')}", + help="Directory to download files to (default: timestamped folder)", +) +@click.option( + "--filename-format", + default="original", + type=click.Choice(["original", "guid", "combined"]), + help="Filename format: 'original' uses the original filename from metadata, 'guid' uses only the file GUID, 'combined' uses original filename with GUID appended (default: original)", +) +@click.option( + "--protocol", + default=None, + help="Protocol for presigned URLs (e.g., s3) (default: auto-detect)", +) +@click.option( + "--max-concurrent-requests", + default=20, + help="Maximum concurrent async downloads per process (default: 20)", + type=int, +) +@click.option( + "--numparallel", + default=None, + help="Number of downloads to run in parallel (compatibility with gen3-client)", + type=int, +) +@click.option( + "--num-processes", + default=3, + help="Number of worker processes for parallel downloads (default: 3)", + type=int, +) +@click.option( + "--queue-size", + default=1000, + help="Maximum items in input queue (default: 1000)", + type=int, +) +@click.option( + "--skip-completed", + type=bool, + default=True, + help="Skip files that already exist (default: true)", +) +@click.option( + "--rename", is_flag=True, help="Rename files if they already exist (default: false)" +) +@click.option( + "--no-prompt", is_flag=True, help="Do not prompt for confirmations (default: false)" +) +@click.option( + "--no-progress", is_flag=True, help="Disable progress bar (default: false)" +) +@click.pass_context +def download_multiple( + ctx, + manifest, + download_path, + filename_format, + protocol, + max_concurrent_requests, + numparallel, + num_processes, + queue_size, + skip_completed=True, + rename=False, + no_prompt=False, + no_progress=False, +): + """ + Asynchronously download multiple files from a manifest with just-in-time presigned URL generation. + """ + auth = ctx.obj["auth_factory"].get() + + # Use numparallel as max_concurrent_requests if provided (for gen3-client compatibility) + if numparallel is not None and max_concurrent_requests == 20: # 20 is the default + max_concurrent_requests = numparallel + + try: + manifest_data = load_manifest(manifest) + + if not validate_manifest(manifest_data): + raise click.ClickException("Invalid manifest format") + + if not manifest_data: + click.echo("No files to download") + return + + if not no_prompt: + click.echo(f"Found {len(manifest_data)} files to download") + if not click.confirm("Continue with async download?"): + click.echo("Download cancelled") + return + + file_client = Gen3File(auth_provider=auth) + + # Debug logging for input parameters + logging.debug( + f"Async download parameters: manifest_data={len(manifest_data)} items, download_path={download_path}, filename_format={filename_format}, protocol={protocol}, max_concurrent_requests={max_concurrent_requests}, skip_completed={skip_completed}, rename={rename}, no_progress={no_progress}" + ) + + loop = get_or_create_event_loop_for_thread() + result = loop.run_until_complete( + file_client.async_download_multiple( + manifest_data=manifest_data, + download_path=download_path, + filename_format=filename_format, + protocol=protocol, + max_concurrent_requests=max_concurrent_requests, + num_processes=num_processes, + queue_size=queue_size, + skip_completed=skip_completed, + rename=rename, + no_progress=no_progress, + ) + ) + + click.echo("\nAsync Download Results:") + click.echo(f"✓ Succeeded: {len(result['succeeded'])}") + + if result["skipped"] and len(result["skipped"]) > 0: + click.echo(f"- Skipped: {len(result['skipped'])}") + + if result["failed"] and len(result["failed"]) > 0: + click.echo(f"✗ Failed: {len(result['failed'])}") + + if result["failed"]: + click.echo("\nFailed downloads:") + for failure in result["failed"]: + click.echo( + f" - {failure.get('guid', 'unknown')}: {failure.get('error', 'Unknown error')}" + ) + + click.echo( + "\nTo retry failed downloads, run the same command with --skip-completed flag:" + ) + + success_rate = len(result["succeeded"]) / len(manifest_data) * 100 + click.echo(f"\nSuccess rate: {success_rate:.1f}%") + + except Exception as e: + logging.error(f"Async batch download failed: {e}") + raise click.ClickException(f"Async batch download failed: {e}") diff --git a/gen3/cli/pfb.py b/gen3/cli/pfb.py index ddf03e0d..0b275768 100644 --- a/gen3/cli/pfb.py +++ b/gen3/cli/pfb.py @@ -23,5 +23,15 @@ def pfb(): pfb.add_command(pfb_cli.main.get_command(ctx=None, cmd_name=command)) # load plug-ins from entry_points -for ep in entry_points().get("gen3.plugins", []): - ep.load() +try: + # For newer Python versions (3.10+) + if hasattr(entry_points(), "select"): + for ep in entry_points().select(group="gen3.plugins"): + ep.load() + else: + # For older Python versions + for ep in entry_points().get("gen3.plugins", []): + ep.load() +except Exception: + # Skip plugin loading if it fails + pass diff --git a/gen3/cli/users.py b/gen3/cli/users.py index b6aad66a..605ef386 100644 --- a/gen3/cli/users.py +++ b/gen3/cli/users.py @@ -25,5 +25,15 @@ def users(): users.add_command(users_cli.main.get_command(ctx=None, cmd_name=command)) # load plug-ins from entry_points -for ep in entry_points().get("gen3.plugins", []): - ep.load() +try: + # For newer Python versions (3.10+) + if hasattr(entry_points(), "select"): + for ep in entry_points().select(group="gen3.plugins"): + ep.load() + else: + # For older Python versions + for ep in entry_points().get("gen3.plugins", []): + ep.load() +except Exception: + # Skip plugin loading if it fails + pass diff --git a/gen3/file.py b/gen3/file.py index 75161574..bb7b4c35 100644 --- a/gen3/file.py +++ b/gen3/file.py @@ -1,26 +1,30 @@ import json import requests -import json import asyncio import aiohttp import aiofiles import time +import multiprocessing as mp +import threading from tqdm import tqdm -from types import SimpleNamespace as Namespace import os -import requests from pathlib import Path +from typing import List, Dict, Any, Optional +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse, quote +from queue import Empty from cdislogging import get_logger from gen3.index import Gen3Index -from gen3.utils import DEFAULT_BACKOFF_SETTINGS, raise_for_status_and_print_error -from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse +from gen3.utils import raise_for_status_and_print_error logging = get_logger("__name__") MAX_RETRIES = 3 +DEFAULT_NUM_PARALLEL = 3 +DEFAULT_MAX_CONCURRENT_REQUESTS = 300 +DEFAULT_QUEUE_SIZE = 1000 class Gen3File: @@ -45,7 +49,6 @@ def __init__(self, endpoint=None, auth_provider=None): # auth_provider legacy interface required endpoint as 1st arg self._auth_provider = auth_provider or endpoint self._endpoint = self._auth_provider.endpoint - self.unsuccessful_downloads = [] def get_presigned_url(self, guid, protocol=None): """Generates a presigned URL for a file. @@ -67,10 +70,26 @@ def get_presigned_url(self, guid, protocol=None): resp = requests.get(api_url, auth=self._auth_provider) raise_for_status_and_print_error(resp) - try: - return resp.json() - except: - return resp.text + return resp.json() + + def get_presigned_urls_batch(self, guids, protocol=None): + """Get presigned URLs for multiple files efficiently. + + Args: + guids (List[str]): List of GUIDs to get presigned URLs for + protocol (str, optional): Protocol preference for URLs + + Returns: + Dict[str, Dict]: Mapping of GUID to presigned URL response + """ + results = {} + for guid in guids: + try: + results[guid] = self.get_presigned_url(guid, protocol) + except Exception as e: + logging.error(f"Failed to get presigned URL for {guid}: {e}") + results[guid] = None + return results def delete_file(self, guid): """ @@ -140,12 +159,7 @@ def upload_file( api_url, auth=self._auth_provider, json=body, headers=headers ) raise_for_status_and_print_error(resp) - try: - data = json.loads(resp.text) - except: - return resp.text - - return data + return resp.json() def _ensure_dirpath_exists(path: Path) -> Path: """Utility to create a directory if missing. @@ -163,19 +177,83 @@ def _ensure_dirpath_exists(path: Path) -> Path: return out_path - def download_single(self, object_id, path): + def download_single( + self, + object_id=None, + path=None, + guid=None, + download_path=None, + filename_format="original", + protocol=None, + skip_completed=True, + rename=False, + ): """ Download a single file using its GUID. Args: - object_id (str): The file's unique ID - path (str): Path to store the downloaded file at + object_id (str): The file's unique ID (legacy parameter) + path (str): Path to store the downloaded file at (legacy parameter) + guid (str): The file's unique ID (new parameter name) + download_path (str): Path to store the downloaded file at (new parameter name) + filename_format (str): Format for filename (original, guid, combined) + protocol (str): Protocol for presigned URL + skip_completed (bool): Skip files that already exist + rename (bool): Rename file if it already exists + + Returns: + dict: Status information about the download """ + # Handle both old and new parameter names + if object_id is not None: + file_guid = object_id + elif guid is not None: + file_guid = guid + else: + raise ValueError("Either object_id or guid must be provided") + + if path is not None: + download_dir = path + elif download_path is not None: + download_dir = download_path + else: + download_dir = "." try: - url = self.get_presigned_url(object_id) + url = self.get_presigned_url(file_guid) except Exception as e: logging.critical(f"Unable to get a presigned URL for download: {e}") - return False + return {"status": "failed", "error": str(e)} + + # Get file metadata + index = Gen3Index(self._auth_provider) + record = index.get_record(file_guid) + + # Determine filename based on format + if filename_format == "guid": + filename = file_guid + elif filename_format == "combined": + original_filename = record.get("file_name", file_guid) + filename = f"{original_filename}_{file_guid}" + else: # original + filename = record.get("file_name", file_guid) + + # Check if file already exists and handle accordingly + out_path = Gen3File._ensure_dirpath_exists(Path(download_dir)) + filepath = os.path.join(out_path, filename) + + if os.path.exists(filepath) and skip_completed: + return { + "status": "skipped", + "reason": "File already exists", + "filepath": filepath, + } + elif os.path.exists(filepath) and rename: + counter = 1 + base_name, ext = os.path.splitext(filename) + while os.path.exists(filepath): + filename = f"{base_name}_{counter}{ext}" + filepath = os.path.join(out_path, filename) + counter += 1 response = requests.get(url["url"], stream=True) if response.status_code != 200: @@ -186,39 +264,38 @@ def download_single(self, object_id, path): # NOTE could be updated with exponential backoff time.sleep(1) response = requests.get(url["url"], stream=True) - if response.status == 200: + if response.status_code == 200: break - if response.status != 200: + if response.status_code != 200: logging.critical("Response status not 200, try again later") - return False + return { + "status": "failed", + "error": "Server error, try again later", + } else: - return False + return {"status": "failed", "error": f"HTTP {response.status_code}"} response.raise_for_status() - total_size_in_bytes = int(response.headers.get("content-length")) + total_size_in_bytes = int(response.headers.get("content-length", 0)) total_downloaded = 0 - index = Gen3Index(self._auth_provider) - record = index.get_record(object_id) - - filename = record["file_name"] - - out_path = Gen3File._ensure_dirpath_exists(Path(path)) - - with open(os.path.join(out_path, filename), "wb") as f: + with open(filepath, "wb") as f: for data in response.iter_content(4096): total_downloaded += len(data) f.write(data) - if total_size_in_bytes == total_downloaded: + if total_size_in_bytes > 0 and total_size_in_bytes == total_downloaded: logging.info(f"File {filename} downloaded successfully") - + return {"status": "downloaded", "filepath": filepath} + elif total_size_in_bytes == 0: + logging.info(f"File {filename} downloaded successfully (unknown size)") + return {"status": "downloaded", "filepath": filepath} else: logging.error(f"File {filename} not downloaded successfully") - return False + return {"status": "failed", "error": "Download incomplete"} - return True + return {"status": "downloaded", "filepath": filepath} def upload_file_to_guid( self, guid, file_name, protocol=None, expires_in=None, bucket=None @@ -259,3 +336,438 @@ def upload_file_to_guid( resp = requests.get(url, auth=self._auth_provider) raise_for_status_and_print_error(resp) return resp.json() + + async def async_download_multiple( + self, + manifest_data, + download_path=".", + filename_format="original", + protocol=None, + max_concurrent_requests=DEFAULT_MAX_CONCURRENT_REQUESTS, + num_processes=DEFAULT_NUM_PARALLEL, + queue_size=DEFAULT_QUEUE_SIZE, + skip_completed=False, + rename=False, + no_progress=False, + ): + """Asynchronously download multiple files using multiprocessing and queues.""" + if not manifest_data: + return {"succeeded": [], "failed": [], "skipped": []} + + guids = [] + for item in manifest_data: + guid = item.get("guid") or item.get("object_id") + if guid: + if "/" in guid: + guid = guid.split("/")[-1] + guids.append(guid) + + if not guids: + logging.error("No valid GUIDs found in manifest data") + return {"succeeded": [], "failed": [], "skipped": []} + + output_dir = Gen3File._ensure_dirpath_exists(Path(download_path)) + + input_queue = mp.Queue(maxsize=queue_size) + output_queue = mp.Queue() + + worker_config = { + "endpoint": self._endpoint, + "auth_provider": self._auth_provider, + "download_path": str(output_dir), + "filename_format": filename_format, + "protocol": protocol, + "max_concurrent": max_concurrent_requests, + "skip_completed": skip_completed, + "rename": rename, + } + + processes = [] + producer_thread = None + + try: + for i in range(num_processes): + p = mp.Process( + target=self._async_worker_process, + args=(input_queue, output_queue, worker_config, i), + ) + p.start() + processes.append(p) + + producer_thread = threading.Thread( + target=self._guid_producer, + args=(guids, input_queue, num_processes), + ) + producer_thread.start() + + results = {"succeeded": [], "failed": [], "skipped": []} + completed_count = 0 + + if not no_progress: + pbar = tqdm(total=len(guids), desc="Downloading") + + while completed_count < len(guids): + try: + batch_results = output_queue.get() + + if not batch_results: + continue + + for result in batch_results: + if result["status"] == "downloaded": + results["succeeded"].append(result["guid"]) + elif result["status"] == "skipped": + results["skipped"].append(result["guid"]) + else: + results["failed"].append(result["guid"]) + + completed_count += 1 + if not no_progress: + pbar.update(1) + + except Empty: + logging.warning( + f"No more results available ({completed_count}/{len(guids)}): Queue is empty" + ) + break + except Exception as e: + logging.warning( + f"Error waiting for results ({completed_count}/{len(guids)}): {e}" + ) + + alive_processes = [p for p in processes if p.is_alive()] + if not alive_processes: + logging.error("All worker processes have died") + break + + if not no_progress: + pbar.close() + + if producer_thread: + producer_thread.join() + + except Exception as e: + logging.error(f"Error in download: {e}") + results = {"succeeded": [], "failed": [], "skipped": [], "error": str(e)} + + finally: + for p in processes: + if p.is_alive(): + p.terminate() + + p.join() + if p.is_alive(): + p.kill() + + logging.info( + f"Download complete: {len(results['succeeded'])} succeeded, " + f"{len(results['failed'])} failed, {len(results['skipped'])} skipped" + ) + return results + + def _guid_producer(self, guids, input_queue, num_processes): + try: + for guid in guids: + input_queue.put(guid) + + except Exception as e: + logging.error(f"Error in producer: {e}") + + @staticmethod + def _async_worker_process(input_queue, output_queue, config, process_id): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete( + Gen3File._worker_main(input_queue, output_queue, config, process_id) + ) + except Exception as e: + logging.error(f"Error in worker process {process_id}: {e}") + finally: + try: + loop.close() + except Exception as e: + logging.warning(f"Error closing event loop in worker {process_id}: {e}") + + @staticmethod + async def _worker_main(input_queue, output_queue, config, process_id): + endpoint = config["endpoint"] + auth_provider = config["auth_provider"] + download_path = Path(config["download_path"]) + filename_format = config["filename_format"] + protocol = config["protocol"] + max_concurrent = config["max_concurrent"] + skip_completed = config["skip_completed"] + rename = config["rename"] + + # Configure connector with optimized settings for large files + timeout = aiohttp.ClientTimeout(total=None, connect=3600, sock_read=3600) + connector = aiohttp.TCPConnector( + limit=max_concurrent * 2, + limit_per_host=max_concurrent, + ttl_dns_cache=300, + use_dns_cache=True, + keepalive_timeout=3600, + enable_cleanup_closed=True, + ) + semaphore = asyncio.Semaphore(max_concurrent) + + async with aiohttp.ClientSession( + connector=connector, timeout=timeout + ) as session: + while True: + try: + # Check if queue is empty with timeout + guid = input_queue.get() + except Empty: + # If queue is empty (timeout), break the loop + break + + # Process single GUID as a batch of one + try: + batch_results = await Gen3File._process_batch( + session, + [guid], + endpoint, + auth_provider, + download_path, + filename_format, + protocol, + semaphore, + skip_completed, + rename, + ) + output_queue.put(batch_results) + except Exception as e: + logging.error( + f"Worker {process_id}: Failed to process {guid} - {type(e).__name__}: {e}" + ) + error_result = [{"guid": guid, "status": "failed", "error": str(e)}] + try: + output_queue.put(error_result) + except Exception as queue_error: + logging.error( + f"Worker {process_id}: Failed to send error result for {guid} - {type(queue_error).__name__}: {queue_error}" + ) + + @staticmethod + async def _process_batch( + session, + guids, + endpoint, + auth_provider, + download_path, + filename_format, + protocol, + semaphore, + skip_completed, + rename, + ): + """Process a batch of GUIDs for downloading.""" + batch_results = [] + for guid in guids: + async with semaphore: + result = await Gen3File._download_single_async( + session, + guid, + endpoint, + auth_provider, + download_path, + filename_format, + protocol, + semaphore, + skip_completed, + rename, + ) + batch_results.append(result) + return batch_results + + @staticmethod + async def _download_single_async( + session, + guid, + endpoint, + auth_provider, + download_path, + filename_format, + protocol, + semaphore, + skip_completed, + rename, + ): + async with semaphore: + try: + metadata = await Gen3File._get_metadata( + session, guid, endpoint, auth_provider.get_access_token() + ) + + original_filename = metadata.get("file_name") + filename = Gen3File._format_filename_static( + guid, original_filename, filename_format + ) + filepath = download_path / filename + + if skip_completed and filepath.exists(): + return { + "guid": guid, + "status": "skipped", + "filepath": str(filepath), + "reason": "File already exists", + } + + filepath = Gen3File._handle_conflict_static(filepath, rename) + + presigned_data = await Gen3File._get_presigned_url_async( + session, guid, endpoint, auth_provider.get_access_token(), protocol + ) + + url = presigned_data.get("url") + if not url: + return { + "guid": guid, + "status": "failed", + "error": "No URL in presigned data", + } + + filepath.parent.mkdir(parents=True, exist_ok=True) + + success = await Gen3File._download_content(session, url, guid, filepath) + if success: + return { + "guid": guid, + "status": "downloaded", + "filepath": str(filepath), + "size": filepath.stat().st_size if filepath.exists() else 0, + } + else: + return { + "guid": guid, + "status": "failed", + "error": "Download failed", + } + + except Exception as e: + logging.error(f"Error downloading {guid}: {e}") + return { + "guid": guid, + "status": "failed", + "error": str(e), + } + + @staticmethod + async def _get_metadata(session, guid, endpoint, auth_token): + encoded_guid = quote(guid, safe="") + api_url = f"{endpoint}/index/{encoded_guid}" + headers = {"Authorization": f"Bearer {auth_token}"} + + try: + async with session.get( + api_url, headers=headers, timeout=aiohttp.ClientTimeout(total=3600) + ) as resp: + if resp.status == 200: + return await resp.json() + raise Exception( + f"Failed to get metadata for {guid}: HTTP {resp.status}" + ) + except aiohttp.ClientError as e: + raise Exception(f"Network error getting metadata for {guid}: {e}") + except asyncio.TimeoutError: + raise Exception(f"Timeout getting metadata for {guid}") + except Exception as e: + if "Failed to get metadata" not in str(e): + raise Exception(f"Unexpected error getting metadata for {guid}: {e}") + raise + + @staticmethod + async def _get_presigned_url_async( + session, guid, endpoint, auth_token, protocol=None + ): + encoded_guid = quote(guid, safe="") + api_url = f"{endpoint}/user/data/download/{encoded_guid}" + headers = {"Authorization": f"Bearer {auth_token}"} + + if protocol: + api_url += f"?protocol={protocol}" + + try: + async with session.get( + api_url, headers=headers, timeout=aiohttp.ClientTimeout(total=3600) + ) as resp: + if resp.status == 200: + return await resp.json() + raise Exception( + f"Failed to get presigned URL for {guid}: HTTP {resp.status}" + ) + except aiohttp.ClientError as e: + raise Exception(f"Network error getting presigned URL for {guid}: {e}") + except asyncio.TimeoutError: + raise Exception(f"Timeout getting presigned URL for {guid}") + except Exception as e: + if "Failed to get presigned URL" not in str(e): + raise Exception( + f"Unexpected error getting presigned URL for {guid}: {e}" + ) + raise + + @staticmethod + async def _download_content(session, url, guid, filepath): + """Download content directly to file with optimized streaming.""" + try: + async with session.get( + url, timeout=aiohttp.ClientTimeout(total=None) + ) as resp: + if resp.status == 200: + async with aiofiles.open(filepath, "wb") as f: + chunk_size = 1024 * 1024 + async for chunk in resp.content.iter_chunked(chunk_size): + await f.write(chunk) + return True + logging.error(f"Download failed for {guid}: HTTP {resp.status}") + return False + except aiohttp.ClientError as e: + logging.error(f"Network error downloading {guid}: {e}") + return False + except asyncio.TimeoutError: + logging.error(f"Timeout downloading {guid}") + return False + except OSError as e: + logging.error(f"File system error downloading {guid} to {filepath}: {e}") + return False + except Exception as e: + logging.error( + f"Unexpected error downloading {guid}: {type(e).__name__}: {e}" + ) + return False + + @staticmethod + def _format_filename_static(guid, original_filename, filename_format): + if filename_format == "guid": + return guid + elif filename_format == "combined": + if original_filename: + name, ext = os.path.splitext(original_filename) + return f"{name}_{guid}{ext}" + return guid + else: + return original_filename or guid + + @staticmethod + def _handle_conflict_static(filepath, rename): + if not rename: + if filepath.exists(): + logging.warning(f"File will be overwritten: {filepath}") + return filepath + + if not filepath.exists(): + return filepath + + counter = 1 + name = filepath.stem + ext = filepath.suffix + parent = filepath.parent + + while True: + new_path = parent / f"{name}_{counter}{ext}" + if not new_path.exists(): + return new_path + counter += 1 diff --git a/pyproject.toml b/pyproject.toml index dfafa89c..e22880e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "gen3" homepage = "https://gen3.org/" -version = "4.27.4" +version = "4.28.0" description = "Gen3 CLI and Python SDK" authors = ["Center for Translational Data Science at the University of Chicago "] license = "Apache-2.0" diff --git a/tests/download_tests/test_async_download.py b/tests/download_tests/test_async_download.py index f42442b3..5f4b8c91 100644 --- a/tests/download_tests/test_async_download.py +++ b/tests/download_tests/test_async_download.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import patch, MagicMock, AsyncMock import json import pytest from pathlib import Path @@ -264,3 +264,30 @@ def test_load_manifest_bad_format(self): manifest_list = _load_manifest(Path(DIR, "resources/bad_format.json")) assert manifest_list == None + + @pytest.mark.asyncio + async def test_async_download_multiple_empty_manifest(self, mock_gen3_auth): + """ + Test async_download_multiple with an empty manifest. + Verifies it returns empty results without errors. + """ + file_tool = Gen3File(mock_gen3_auth) + result = await file_tool.async_download_multiple(manifest_data=[]) + + assert result == {"succeeded": [], "failed": [], "skipped": []} + + @pytest.mark.asyncio + async def test_async_download_multiple_invalid_guids(self, mock_gen3_auth): + """ + Test async_download_multiple with invalid GUIDs. + Verifies it returns empty results for missing GUIDs. + """ + file_tool = Gen3File(mock_gen3_auth) + + # Manifest with missing guid/object_id fields + manifest_data = [{"file_name": "test.txt"}, {}] + + result = await file_tool.async_download_multiple(manifest_data=manifest_data) + + assert result == {"succeeded": [], "failed": [], "skipped": []} + diff --git a/tests/test_file.py b/tests/test_file.py index a02a80eb..60e3f641 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -1,9 +1,11 @@ """ Tests gen3.file.Gen3File for calls """ -from unittest.mock import patch -import json + +from unittest.mock import patch, MagicMock import pytest +import tempfile +from pathlib import Path from requests import HTTPError @@ -248,20 +250,37 @@ def test_upload_file( expected response to compare with mock """ with patch("gen3.file.requests") as mock_request: - mock_request.status_code = status_code - mock_request.post().text = response_text - res = gen3_file.upload_file( - file_name="file.txt", - authz=authz, - protocol=supported_protocol, - expires_in=expires_in, + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = response_text + mock_response.json.return_value = ( + expected_response if status_code == 201 else {} ) + + # Make raise_for_status() raise HTTPError for non-2xx status codes + if status_code >= 400: + mock_response.raise_for_status.side_effect = HTTPError() + + mock_request.post.return_value = mock_response + if status_code == 201: + res = gen3_file.upload_file( + file_name="file.txt", + authz=authz, + protocol=supported_protocol, + expires_in=expires_in, + ) # check that the SDK is getting fence assert res.get("url") == expected_response["url"] else: - # check the error message - assert expected_response in res + # For non-201 status codes, the method should raise an exception + with pytest.raises(HTTPError): + gen3_file.upload_file( + file_name="file.txt", + authz=authz, + protocol=supported_protocol, + expires_in=expires_in, + ) @pytest.mark.parametrize( @@ -326,7 +345,7 @@ def test_upload_file_no_refresh_token(gen3_file, supported_protocol, authz, expi def test_upload_file_no_api_key(gen3_file, supported_protocol, authz, expires_in): """ Upload files for a Gen3File given a protocol, authz, and expires_in - without an api_key in the refresh token, which should return a 401 + without an api_key in the refresh token, which should raise an HTTPError :param gen3.file.Gen3File gen3_file: Gen3File object @@ -341,15 +360,19 @@ def test_upload_file_no_api_key(gen3_file, supported_protocol, authz, expires_in gen3_file._auth_provider._refresh_token = {"not_api_key": "123"} with patch("gen3.file.requests") as mock_request: - mock_request.status_code = 401 - mock_request.post().text = "Failed to upload data file." - res = gen3_file.upload_file( - file_name="file.txt", - authz=authz, - protocol=supported_protocol, - expires_in=expires_in, - ) - assert res == "Failed to upload data file." + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Failed to upload data file." + mock_response.raise_for_status.side_effect = HTTPError() + mock_request.post.return_value = mock_response + + with pytest.raises(HTTPError): + gen3_file.upload_file( + file_name="file.txt", + authz=authz, + protocol=supported_protocol, + expires_in=expires_in, + ) @pytest.mark.parametrize( @@ -371,7 +394,7 @@ def test_upload_file_no_api_key(gen3_file, supported_protocol, authz, expires_in def test_upload_file_wrong_api_key(gen3_file, supported_protocol, authz, expires_in): """ Upload files for a Gen3File given a protocol, authz, and expires_in - with the wrong value for the api_key in the refresh token, which should return a 401 + with the wrong value for the api_key in the refresh token, which should raise an HTTPError :param gen3.file.Gen3File gen3_file: Gen3File object @@ -386,12 +409,213 @@ def test_upload_file_wrong_api_key(gen3_file, supported_protocol, authz, expires gen3_file._auth_provider._refresh_token = {"api_key": "wrong_value"} with patch("gen3.file.requests") as mock_request: - mock_request.status_code = 401 - mock_request.post().text = "Failed to upload data file." - res = gen3_file.upload_file( - file_name="file.txt", - authz=authz, - protocol=supported_protocol, - expires_in=expires_in, + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Failed to upload data file." + mock_response.raise_for_status.side_effect = HTTPError() + mock_request.post.return_value = mock_response + + with pytest.raises(HTTPError): + gen3_file.upload_file( + file_name="file.txt", + authz=authz, + protocol=supported_protocol, + expires_in=expires_in, + ) + + +@pytest.fixture +def mock_manifest_data(): + return [ + {"guid": "test-guid-1", "file_name": "file1.txt"}, + {"guid": "test-guid-2", "file_name": "file2.txt"}, + {"object_id": "test-guid-3", "file_name": "file3.txt"}, + ] + + +def test_download_single_success(gen3_file): + """ + Test successful download of a single file via download_single method. + + Verifies that download_single correctly downloads a file using synchronous requests + and returns a success status dictionary. + """ + gen3_file._auth_provider._refresh_token = {"api_key": "123"} + + with ( + patch.object(gen3_file, "get_presigned_url") as mock_presigned, + patch("gen3.file.requests.get") as mock_get, + patch("gen3.index.Gen3Index.get_record") as mock_index, + patch("os.path.exists", return_value=False), + ): + mock_presigned.return_value = {"url": "https://fake-url.com/file"} + mock_index.return_value = {"file_name": "test-file.txt"} + mock_response = MagicMock() + mock_response.status_code = 200 + test_content = b"test content" + mock_response.headers = {"content-length": str(len(test_content))} + mock_response.iter_content = lambda size: [test_content] + mock_get.return_value = mock_response + + result = gen3_file.download_single(object_id="test-guid", path="/tmp") + + assert result["status"] == "downloaded" + assert "filepath" in result + + +def test_download_single_failed(gen3_file): + """ + Test failed download of a single file via download_single method. + + Verifies that download_single correctly handles failures and returns a failure status dictionary. + """ + gen3_file._auth_provider._refresh_token = {"api_key": "123"} + + with ( + patch.object(gen3_file, "get_presigned_url") as mock_presigned, + patch("gen3.file.requests.get") as mock_get, + ): + mock_presigned.return_value = {"url": "https://fake-url.com/file"} + mock_response = MagicMock() + mock_response.status_code = 404 + mock_get.return_value = mock_response + + result = gen3_file.download_single(object_id="test-guid", path="/tmp") + + assert result["status"] == "failed" + assert "error" in result + + +@pytest.mark.asyncio +async def test_async_download_multiple_empty_manifest(gen3_file): + """ + Test async_download_multiple with an empty manifest. + + Verifies that calling async_download_multiple with an empty manifest + returns empty succeeded, failed, and skipped lists. + """ + result = await gen3_file.async_download_multiple(manifest_data=[]) + assert result == {"succeeded": [], "failed": [], "skipped": []} + + +@pytest.mark.asyncio +async def test_async_download_multiple_success(gen3_file, mock_manifest_data): + """ + Test successful async download of multiple files. + + Verifies that async_download_multiple correctly processes a manifest with + multiple files and returns all downloads as successful. + """ + gen3_file._auth_provider._refresh_token = {"api_key": "123"} + gen3_file._auth_provider.get_access_token = MagicMock(return_value="fake_token") + + with ( + patch("gen3.file.mp.Process"), + patch("gen3.file.mp.Queue") as mock_queue, + patch("threading.Thread"), + ): + mock_input_queue = MagicMock() + mock_output_queue = MagicMock() + mock_queue.side_effect = [mock_input_queue, mock_output_queue] + + mock_output_queue.get.side_effect = [ + [{"guid": "test-guid-1", "status": "downloaded"}], + [{"guid": "test-guid-2", "status": "downloaded"}], + [{"guid": "test-guid-3", "status": "downloaded"}], + ] + + result = await gen3_file.async_download_multiple( + manifest_data=mock_manifest_data, download_path="/tmp" ) - assert res == "Failed to upload data file." + + assert len(result["succeeded"]) == 3 + + +def test_get_presigned_urls_batch(gen3_file): + """ + Test batch retrieval of presigned URLs for multiple GUIDs. + + Verifies that get_presigned_urls_batch correctly calls get_presigned_url + for each GUID and returns a mapping of results. + """ + gen3_file._auth_provider._refresh_token = {"api_key": "123"} + + with patch.object(gen3_file, "get_presigned_url") as mock_get_url: + mock_get_url.return_value = {"url": "https://example.com/presigned"} + + results = gen3_file.get_presigned_urls_batch(["guid1", "guid2"]) + + assert len(results) == 2 + assert mock_get_url.call_count == 2 + + +def test_format_filename_static(): + """ + Test the static _format_filename_static method with different filename formats. + + Verifies that files can be formatted as original, guid-only, or combined + (filename_guidXXX.ext) based on the format parameter. + """ + from gen3.file import Gen3File + + assert ( + Gen3File._format_filename_static("guid123", "test.txt", "original") + == "test.txt" + ) + assert Gen3File._format_filename_static("guid123", "test.txt", "guid") == "guid123" + assert ( + Gen3File._format_filename_static("guid123", "test.txt", "combined") + == "test_guid123.txt" + ) + + +def test_handle_conflict_static(): + """ + Test the static _handle_conflict_static method for file conflict resolution. + + Verifies that existing files can be either kept or renamed with a numeric + suffix based on the rename parameter. + """ + from gen3.file import Gen3File + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + existing_file = temp_path / "existing.txt" + existing_file.write_text("test") + + result = Gen3File._handle_conflict_static(existing_file, rename=False) + assert result == existing_file + + result = Gen3File._handle_conflict_static(existing_file, rename=True) + assert result.name == "existing_1.txt" + + +def test_download_single_basic_functionality(gen3_file): + """ + Test download_single basic functionality with synchronous download. + + Verifies that download_single downloads a file successfully using + synchronous requests and returns a success status dictionary. + """ + gen3_file._auth_provider._refresh_token = {"api_key": "123"} + + with ( + patch.object(gen3_file, "get_presigned_url") as mock_presigned, + patch("gen3.file.requests.get") as mock_get, + patch("gen3.index.Gen3Index.get_record") as mock_index, + patch("os.path.exists", return_value=False), + ): + mock_presigned.return_value = {"url": "https://fake-url.com/file"} + mock_index.return_value = {"file_name": "test-file.txt"} + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-length": "12"} + mock_response.iter_content = lambda size: [b"test content"] + mock_get.return_value = mock_response + + result = gen3_file.download_single(object_id="test-guid", path="/tmp") + + assert result["status"] == "downloaded" + assert "filepath" in result + mock_presigned.assert_called_once_with("test-guid") + mock_index.assert_called_once_with("test-guid")