From 03e1e132b2cd47a03bea1a5ef7562fc19c6e0b89 Mon Sep 17 00:00:00 2001 From: jerome_Hsieh Date: Sun, 17 Nov 2024 01:47:06 +0800 Subject: [PATCH] enhance download_and_extract --- monai/apps/utils.py | 48 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index c2e17d3247..c6d1d8cd8d 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -327,7 +327,47 @@ def download_and_extract( be False. progress: whether to display progress bar. """ - with tempfile.TemporaryDirectory() as tmp_dir: - filename = filepath or Path(tmp_dir, _basename(url)).resolve() - download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) - extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) +def download_and_extract( + url: str, + filepath: PathLike = "", + output_dir: PathLike = ".", + hash_val: str | None = None, + hash_type: str = "md5", + file_type: str = "", + has_base: bool = True, + progress: bool = True, +) -> None: + """ + Download file from URL and extract it to the output directory. + + Args: + url: source URL link to download file. + filepath: the file path of the downloaded compressed file. + use this option to keep the directly downloaded compressed file, to avoid further repeated downloads. + output_dir: target directory to save extracted files. + default is the current directory. + hash_val: expected hash value to validate the downloaded file. + if None, skip hash validation. + hash_type: 'md5' or 'sha1', defaults to 'md5'. + file_type: string of file type for decompressing. Leave it empty to infer the type from url's base file name. + has_base: whether the extracted files have a base folder. This flag is used when checking if the existing + folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped + to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should + be False. + progress: whether to display progress bar. + """ + urlFilenameExtension = ''.join(Path(".", _basename(url)).resolve().suffixes) + if filepath: + FilepathExtenstion = ''.join(Path(".", _basename(filepath)).resolve().suffixes) + if urlFilenameExtension != FilepathExtenstion: + raise NotImplementedError( + f'The file types do not match: url={urlFilenameExtension}, but filepath={FilepathExtenstion}' + ) + else: + with tempfile.TemporaryDirectory() as tmp_dir: + if filepath: + filename = filepath + else: + filename = Path(tmp_dir, _basename(url)).resolve() + download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) + extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)