From c0dfc5e0aec45ebf514eb943189b4400d0dc62a2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 16 Jun 2020 00:31:54 +0800 Subject: [PATCH 01/11] [DLMED] add MedNIST dataset --- monai/application/datasets.py | 81 +++++++++++++++++++++++++++++++++++ monai/application/utils.py | 24 +++++++++++ 2 files changed, 105 insertions(+) create mode 100644 monai/application/datasets.py create mode 100644 monai/application/utils.py diff --git a/monai/application/datasets.py b/monai/application/datasets.py new file mode 100644 index 0000000000..91a8dc2afc --- /dev/null +++ b/monai/application/datasets.py @@ -0,0 +1,81 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import tarfile +import wget +import numpy as np +from monai.data import CacheDataset +from .utils import download_url, extractall + + +class MedNISTDataset(Randomizable, CacheDataset): + resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz" + + def __init__(self, root, transform, section, download=False, cache_num=sys.maxsize, cache_rate=1.0, num_workers=0, seed=0): + if not os.path.isdir(root): + raise ValueError("root must be a directory.") + self.root = root + self.section = section + self.set_random_state(seed=seed) + self.tarfile_name = os.path.join(self.root, "MedNIST.tar.gz") + self.dataset_dir = os.path.join(self.root, "MedNIST") + if download: + self._download() + if os.path.exists(self.tarfile_name) and not os.path.exists(self.dataset_dir): + extractall() + if not os.path.exists(self.dataset_dir): + raise RuntimeError('can not find dataset directory, please use download=True to download it.') + data = self._generate_data_list() + super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers) + + def randomize(self): + self.rann = self.R.random() + + def _download(self): + if os.path.exists(self.tarfile_name) or os.path.exists(self.dataset_dir): + return + os.makedirs(self.root, exist_ok=True) + download_url(self.resource, self.tarfile_name) + + def _generate_data_list(self): + class_names = sorted([x for x in os.listdir(self.dataset_dir) if os.path.isdir(os.path.join(self.dataset_dir, x))]) + num_class = len(class_names) + image_files = [[os.path.join(self.dataset_dir, class_names[i], x) + for x in os.listdir(os.path.join(self.dataset_dir, class_names[i]))] + for i in range(num_class)] + num_each = [len(image_files[i]) for i in range(num_class)] + image_files_list = [] + image_class = [] + for i in range(num_class): + image_files_list.extend(image_files[i]) + image_class.extend([i] * num_each[i]) + num_total = len(image_class) + + val_frac = 0.1 + test_frac = 0.1 + data = list() + + for 
i in range(num_total): + self.randomize() + if self.section == "training": + if self.rann < val_frac + test_frac: + continue + elif self.section == "validation": + if self.rann >= val_frac: + continue + elif self.section == "test": + if self.rann < val_frac or self.rann >= val_frac + test_frac: + continue + else: + raise ValueError("section name can only be: training, validation or test.") + data.append({"image": image_files_list[i], "label": image_class[i]}) + return data diff --git a/monai/application/utils.py b/monai/application/utils.py new file mode 100644 index 0000000000..2daae460c4 --- /dev/null +++ b/monai/application/utils.py @@ -0,0 +1,24 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import wget +import tarfile + + +def download_url(url, filename=None): + return wget.download(url, out=filename) + + +def extractall(filepath): + datafile = tarfile.open(filepath) + datafile.extractall() + datafile.close() From a449a9ebdddc3669ba427d51d52a1c0e948e9cbb Mon Sep 17 00:00:00 2001 From: monai-bot Date: Mon, 15 Jun 2020 16:34:49 +0000 Subject: [PATCH 02/11] [MONAI] python code formatting --- monai/application/datasets.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/monai/application/datasets.py b/monai/application/datasets.py index 91a8dc2afc..85fe47a7dd 100644 --- a/monai/application/datasets.py +++ b/monai/application/datasets.py @@ -20,7 +20,9 @@ class MedNISTDataset(Randomizable, CacheDataset): resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz" - def __init__(self, root, transform, section, download=False, cache_num=sys.maxsize, cache_rate=1.0, num_workers=0, seed=0): + def __init__( + self, root, transform, section, download=False, cache_num=sys.maxsize, cache_rate=1.0, num_workers=0, seed=0 + ): if not os.path.isdir(root): raise ValueError("root must be a directory.") self.root = root @@ -33,7 +35,7 @@ def __init__(self, root, transform, section, download=False, cache_num=sys.maxsi if os.path.exists(self.tarfile_name) and not os.path.exists(self.dataset_dir): extractall() if not os.path.exists(self.dataset_dir): - raise RuntimeError('can not find dataset directory, please use download=True to download it.') + raise RuntimeError("can not find dataset directory, please use download=True to download it.") data = self._generate_data_list() super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers) @@ -47,11 +49,17 @@ def _download(self): download_url(self.resource, self.tarfile_name) def _generate_data_list(self): - class_names = sorted([x for x in os.listdir(self.dataset_dir) if os.path.isdir(os.path.join(self.dataset_dir, x))]) + 
class_names = sorted( + [x for x in os.listdir(self.dataset_dir) if os.path.isdir(os.path.join(self.dataset_dir, x))] + ) num_class = len(class_names) - image_files = [[os.path.join(self.dataset_dir, class_names[i], x) - for x in os.listdir(os.path.join(self.dataset_dir, class_names[i]))] - for i in range(num_class)] + image_files = [ + [ + os.path.join(self.dataset_dir, class_names[i], x) + for x in os.listdir(os.path.join(self.dataset_dir, class_names[i])) + ] + for i in range(num_class) + ] num_each = [len(image_files[i]) for i in range(num_class)] image_files_list = [] image_class = [] From a18c921e0326eec75715c7cd2436244066ebce2d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 16 Jun 2020 16:42:35 +0800 Subject: [PATCH 03/11] [DLMED] add unit tests --- monai/application/__init__.py | 3 +++ monai/application/datasets.py | 39 ++++++++++++++++++++++++++---- monai/application/utils.py | 45 +++++++++++++++++++++++++++++++---- tests/test_mednistdataset.py | 41 +++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 tests/test_mednistdataset.py diff --git a/monai/application/__init__.py b/monai/application/__init__.py index d0044e3563..5e0ce25a6b 100644 --- a/monai/application/__init__.py +++ b/monai/application/__init__.py @@ -8,3 +8,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .utils import * +from .datasets import * diff --git a/monai/application/datasets.py b/monai/application/datasets.py index 85fe47a7dd..fae419badd 100644 --- a/monai/application/datasets.py +++ b/monai/application/datasets.py @@ -10,18 +10,49 @@ # limitations under the License. 
import os +import sys import tarfile import wget import numpy as np +from typing import Callable, Optional from monai.data import CacheDataset +from monai.transforms import Randomizable from .utils import download_url, extractall class MedNISTDataset(Randomizable, CacheDataset): - resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz" + """ + The Dataset to automatically download MedNIST data and generate items for training, validation or test. + It's based on `CacheDataset` to accelerate the training process. + + Args: + root: target dictionary to download and load MedNIST dataset. + section: expected data section, can be: `training`, `validation` or `test`. + download: whether to download the MedNIST from resource link, default is False. + if expected file already exists, skip downloading even set it to True. + seed: random seed to randomly split training, validation and test datasets, defaut is 0. + transform: transforms to execute operations on input data. + cache_num: number of items to be cached. Default is `sys.maxsize`. + will take the minimum of (cache_num, data_length x cache_rate, data_length). + cache_rate: percentage of cached data in total, default is 1.0 (cache all). + will take the minimum of (cache_num, data_length x cache_rate, data_length). + num_workers: the number of worker threads to use. + If 0 a single thread will be used. Default is 0. 
+ + """ + resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1" + md5 = "0bc7306e7427e00ad1c5526a6677552d" def __init__( - self, root, transform, section, download=False, cache_num=sys.maxsize, cache_rate=1.0, num_workers=0, seed=0 + self, + root: str, + section: str, + download: bool = False, + seed: int = 0, + transform: Optional[Callable] = None, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_workers: int = 0 ): if not os.path.isdir(root): raise ValueError("root must be a directory.") @@ -33,7 +64,7 @@ def __init__( if download: self._download() if os.path.exists(self.tarfile_name) and not os.path.exists(self.dataset_dir): - extractall() + extractall(self.tarfile_name, self.root) if not os.path.exists(self.dataset_dir): raise RuntimeError("can not find dataset directory, please use download=True to download it.") data = self._generate_data_list() @@ -46,7 +77,7 @@ def _download(self): if os.path.exists(self.tarfile_name) or os.path.exists(self.dataset_dir): return os.makedirs(self.root, exist_ok=True) - download_url(self.resource, self.tarfile_name) + download_url(self.resource, self.tarfile_name, self.md5) def _generate_data_list(self): class_names = sorted( diff --git a/monai/application/utils.py b/monai/application/utils.py index 2daae460c4..cfb1be89a2 100644 --- a/monai/application/utils.py +++ b/monai/application/utils.py @@ -10,15 +10,50 @@ # limitations under the License. import os -import wget +import urllib +import hashlib import tarfile +from monai.utils import process_bar -def download_url(url, filename=None): - return wget.download(url, out=filename) +def download_url(url: str, filepath: str, md5_value: str = None): + """ + Download file from specified URL link, support process bar and MD5 check. + Args: + url: source URL link to download file. + filepath: target filepath to save the downloaded file. + md5_value: expected MD5 value to validate the downloaded file. + if None, skip MD5 validation. 
-def extractall(filepath): + """ + def _process_hook(blocknum, blocksize, totalsize): + process_bar(blocknum * blocksize, totalsize) + + try: + urllib.request.urlretrieve(url, filepath, reporthook=_process_hook) + print(f"\ndownloaded file: {filepath}.") + except (urllib.error.URLError, IOError) as e: + raise e + + if md5_value is not None: + md5 = hashlib.md5() + with open(filepath, 'rb') as f: + for chunk in iter(lambda: f.read(1024 * 1024), b''): + md5.update(chunk) + if md5_value != md5.hexdigest(): + raise RuntimeError("MD5 check of downloaded file failed.") + + +def extractall(filepath: str, output_dir: str = None): + """ + Extract file to the output directory. + + Args: + filepath: the file path of compressed file. + output_dir: target directory to save extracted files. + defaut is None to save in current directory. + """ datafile = tarfile.open(filepath) - datafile.extractall() + datafile.extractall(output_dir) datafile.close() diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py new file mode 100644 index 0000000000..cee3226817 --- /dev/null +++ b/tests/test_mednistdataset.py @@ -0,0 +1,41 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import os +import shutil +import tempfile + +from monai.application import MedNISTDataset +from tests.utils import NumpyImageTestCase2D +from monai.transforms import LoadPNGd, AddChanneld, ScaleIntensityd, ToTensord, Compose + + +class TestMedNISTDataset(unittest.TestCase): + def test_values(self): + tempdir = tempfile.mkdtemp() + transform = Compose([ + LoadPNGd(keys="image"), + AddChanneld(keys="image"), + ScaleIntensityd(keys="image"), + ToTensord(keys=["image", "label"]) + ]) + dataset = MedNISTDataset(root=tempdir, transform=transform, section="test", download=True) + self.assertEqual(len(dataset), 5986) + self.assertTrue("image" in dataset[0]) + self.assertTrue("label" in dataset[0]) + self.assertTrue("image_meta_dict" in dataset[0]) + self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) + shutil.rmtree(tempdir) + + +if __name__ == "__main__": + unittest.main() From 17a8eb78f17cd17065d17eea6d6bb0ee8063564b Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 16 Jun 2020 08:44:54 +0000 Subject: [PATCH 04/11] [MONAI] python code formatting --- monai/application/datasets.py | 3 ++- monai/application/utils.py | 5 +++-- tests/test_mednistdataset.py | 14 ++++++++------ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/monai/application/datasets.py b/monai/application/datasets.py index fae419badd..d0c9ce11d0 100644 --- a/monai/application/datasets.py +++ b/monai/application/datasets.py @@ -40,6 +40,7 @@ class MedNISTDataset(Randomizable, CacheDataset): If 0 a single thread will be used. Default is 0. 
""" + resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1" md5 = "0bc7306e7427e00ad1c5526a6677552d" @@ -52,7 +53,7 @@ def __init__( transform: Optional[Callable] = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, - num_workers: int = 0 + num_workers: int = 0, ): if not os.path.isdir(root): raise ValueError("root must be a directory.") diff --git a/monai/application/utils.py b/monai/application/utils.py index cfb1be89a2..b8c2196d99 100644 --- a/monai/application/utils.py +++ b/monai/application/utils.py @@ -27,6 +27,7 @@ def download_url(url: str, filepath: str, md5_value: str = None): if None, skip MD5 validation. """ + def _process_hook(blocknum, blocksize, totalsize): process_bar(blocknum * blocksize, totalsize) @@ -38,8 +39,8 @@ def _process_hook(blocknum, blocksize, totalsize): if md5_value is not None: md5 = hashlib.md5() - with open(filepath, 'rb') as f: - for chunk in iter(lambda: f.read(1024 * 1024), b''): + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): md5.update(chunk) if md5_value != md5.hexdigest(): raise RuntimeError("MD5 check of downloaded file failed.") diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index cee3226817..c41182bd83 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -22,12 +22,14 @@ class TestMedNISTDataset(unittest.TestCase): def test_values(self): tempdir = tempfile.mkdtemp() - transform = Compose([ - LoadPNGd(keys="image"), - AddChanneld(keys="image"), - ScaleIntensityd(keys="image"), - ToTensord(keys=["image", "label"]) - ]) + transform = Compose( + [ + LoadPNGd(keys="image"), + AddChanneld(keys="image"), + ScaleIntensityd(keys="image"), + ToTensord(keys=["image", "label"]), + ] + ) dataset = MedNISTDataset(root=tempdir, transform=transform, section="test", download=True) self.assertEqual(len(dataset), 5986) self.assertTrue("image" in dataset[0]) From 2277ead5212cf64564ebe672e328f74b04824a30 Mon Sep 17 
00:00:00 2001 From: Nic Ma Date: Tue, 16 Jun 2020 16:51:06 +0800 Subject: [PATCH 05/11] [DLMED] remove typo --- monai/application/datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/application/datasets.py b/monai/application/datasets.py index d0c9ce11d0..4f7c9a88d3 100644 --- a/monai/application/datasets.py +++ b/monai/application/datasets.py @@ -12,7 +12,6 @@ import os import sys import tarfile -import wget import numpy as np from typing import Callable, Optional from monai.data import CacheDataset From db8e580e834a90620e1ebcfb5f4c9981d042cacc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 17 Jun 2020 00:22:51 +0800 Subject: [PATCH 06/11] [DLMED] update according to the comments --- monai/application/datasets.py | 36 ++++++++++++++++------------------- monai/application/utils.py | 26 ++++++++++++++++++++++++- tests/test_mednistdataset.py | 2 +- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/monai/application/datasets.py b/monai/application/datasets.py index 4f7c9a88d3..85dba97260 100644 --- a/monai/application/datasets.py +++ b/monai/application/datasets.py @@ -13,7 +13,7 @@ import sys import tarfile import numpy as np -from typing import Callable, Optional +from typing import Any, Callable from monai.data import CacheDataset from monai.transforms import Randomizable from .utils import download_url, extractall @@ -29,6 +29,8 @@ class MedNISTDataset(Randomizable, CacheDataset): section: expected data section, can be: `training`, `validation` or `test`. download: whether to download the MedNIST from resource link, default is False. if expected file already exists, skip downloading even set it to True. + user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory. + extract: whether to extract the MedNIST.tar.gz file under root directory, default is False. seed: random seed to randomly split training, validation and test datasets, defaut is 0. transform: transforms to execute operations on input data. 
cache_num: number of items to be cached. Default is `sys.maxsize`. @@ -48,46 +50,40 @@ def __init__( root: str, section: str, download: bool = False, + extract: bool = False, seed: int = 0, - transform: Optional[Callable] = None, + transform: Callable[..., Any] = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, ): if not os.path.isdir(root): raise ValueError("root must be a directory.") - self.root = root self.section = section self.set_random_state(seed=seed) - self.tarfile_name = os.path.join(self.root, "MedNIST.tar.gz") - self.dataset_dir = os.path.join(self.root, "MedNIST") + tarfile_name = os.path.join(root, "MedNIST.tar.gz") + dataset_dir = os.path.join(root, "MedNIST") if download: - self._download() - if os.path.exists(self.tarfile_name) and not os.path.exists(self.dataset_dir): - extractall(self.tarfile_name, self.root) - if not os.path.exists(self.dataset_dir): + download_url(self.resource, tarfile_name, self.md5) + if extract: + extractall(tarfile_name, root) + if not os.path.exists(dataset_dir): raise RuntimeError("can not find dataset directory, please use download=True to download it.") - data = self._generate_data_list() + data = self._generate_data_list(dataset_dir) super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers) def randomize(self): self.rann = self.R.random() - def _download(self): - if os.path.exists(self.tarfile_name) or os.path.exists(self.dataset_dir): - return - os.makedirs(self.root, exist_ok=True) - download_url(self.resource, self.tarfile_name, self.md5) - - def _generate_data_list(self): + def _generate_data_list(self, dataset_dir): class_names = sorted( - [x for x in os.listdir(self.dataset_dir) if os.path.isdir(os.path.join(self.dataset_dir, x))] + [x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x))] ) num_class = len(class_names) image_files = [ [ - os.path.join(self.dataset_dir, class_names[i], x) - for x in 
os.listdir(os.path.join(self.dataset_dir, class_names[i])) + os.path.join(dataset_dir, class_names[i], x) + for x in os.listdir(os.path.join(dataset_dir, class_names[i])) ] for i in range(num_class) ] diff --git a/monai/application/utils.py b/monai/application/utils.py index b8c2196d99..f1cea0b03d 100644 --- a/monai/application/utils.py +++ b/monai/application/utils.py @@ -27,6 +27,10 @@ def download_url(url: str, filepath: str, md5_value: str = None): if None, skip MD5 validation. """ + if os.path.exists(filepath): + print(f"file {filepath} exists, skip downloading.") + return + os.makedirs(os.path.dirname(filepath), exist_ok=True) def _process_hook(blocknum, blocksize, totalsize): process_bar(blocknum * blocksize, totalsize) @@ -43,7 +47,8 @@ def _process_hook(blocknum, blocksize, totalsize): for chunk in iter(lambda: f.read(1024 * 1024), b""): md5.update(chunk) if md5_value != md5.hexdigest(): - raise RuntimeError("MD5 check of downloaded file failed.") + raise RuntimeError(f"MD5 check of downloaded file failed, \ + URL={url}, filepath={filepath}, expected MD5={md5_value}") def extractall(filepath: str, output_dir: str = None): @@ -55,6 +60,25 @@ def extractall(filepath: str, output_dir: str = None): output_dir: target directory to save extracted files. defaut is None to save in current directory. """ + target_file = os.path.join(output_dir, os.path.basename(filepath).split(".")[0]) + if os.path.exists(target_file): + print(f"extracted file {target_file} exists, skip extracting.") datafile = tarfile.open(filepath) datafile.extractall(output_dir) datafile.close() + + +def download_and_extract(url: str, filepath: str, md5_value: str = None, output_dir: str = None): + """ + Download file from URL and extract it to the output directory. + + Args: + url: source URL link to download file. + filepath: the file path of compressed file. + md5_value: expected MD5 value to validate the downloaded file. + if None, skip MD5 validation. 
+ output_dir: target directory to save extracted files. + defaut is None to save in current directory. + """ + download_url(url=url, filepath=filepath, md5_value=md5_value) + extractall(filepath=filepath, output_dir=output_dir) diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index c41182bd83..cc472a5798 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -30,7 +30,7 @@ def test_values(self): ToTensord(keys=["image", "label"]), ] ) - dataset = MedNISTDataset(root=tempdir, transform=transform, section="test", download=True) + dataset = MedNISTDataset(root=tempdir, transform=transform, section="test", download=True, extract=True) self.assertEqual(len(dataset), 5986) self.assertTrue("image" in dataset[0]) self.assertTrue("label" in dataset[0]) From 9d2a93eee57107bfc2593ec3c48851f1b8b8e495 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 16 Jun 2020 16:24:36 +0000 Subject: [PATCH 07/11] [MONAI] python code formatting --- monai/application/datasets.py | 4 +--- monai/application/utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/application/datasets.py b/monai/application/datasets.py index 85dba97260..a603723b26 100644 --- a/monai/application/datasets.py +++ b/monai/application/datasets.py @@ -76,9 +76,7 @@ def randomize(self): self.rann = self.R.random() def _generate_data_list(self, dataset_dir): - class_names = sorted( - [x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x))] - ) + class_names = sorted([x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x))]) num_class = len(class_names) image_files = [ [ diff --git a/monai/application/utils.py b/monai/application/utils.py index f1cea0b03d..8b3ed18eb9 100644 --- a/monai/application/utils.py +++ b/monai/application/utils.py @@ -47,8 +47,10 @@ def _process_hook(blocknum, blocksize, totalsize): for chunk in iter(lambda: f.read(1024 * 1024), b""): md5.update(chunk) if md5_value 
!= md5.hexdigest(): - raise RuntimeError(f"MD5 check of downloaded file failed, \ - URL={url}, filepath={filepath}, expected MD5={md5_value}") + raise RuntimeError( + f"MD5 check of downloaded file failed, \ + URL={url}, filepath={filepath}, expected MD5={md5_value}" + ) def extractall(filepath: str, output_dir: str = None): From 84c5f8143fcbcd18288b75061fb05f56bd7855f1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 17 Jun 2020 00:42:00 +0800 Subject: [PATCH 08/11] [DLMED] fix type hints error --- monai/application/datasets.py | 4 ++-- monai/application/utils.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/monai/application/datasets.py b/monai/application/datasets.py index a603723b26..dee7d5072c 100644 --- a/monai/application/datasets.py +++ b/monai/application/datasets.py @@ -27,12 +27,12 @@ class MedNISTDataset(Randomizable, CacheDataset): Args: root: target dictionary to download and load MedNIST dataset. section: expected data section, can be: `training`, `validation` or `test`. + transform: transforms to execute operations on input data. download: whether to download the MedNIST from resource link, default is False. if expected file already exists, skip downloading even set it to True. user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory. extract: whether to extract the MedNIST.tar.gz file under root directory, default is False. seed: random seed to randomly split training, validation and test datasets, defaut is 0. - transform: transforms to execute operations on input data. cache_num: number of items to be cached. Default is `sys.maxsize`. will take the minimum of (cache_num, data_length x cache_rate, data_length). cache_rate: percentage of cached data in total, default is 1.0 (cache all). 
@@ -49,10 +49,10 @@ def __init__( self, root: str, section: str, + transform: Callable[..., Any], download: bool = False, extract: bool = False, seed: int = 0, - transform: Callable[..., Any] = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, diff --git a/monai/application/utils.py b/monai/application/utils.py index 8b3ed18eb9..3623d3ea76 100644 --- a/monai/application/utils.py +++ b/monai/application/utils.py @@ -10,7 +10,8 @@ # limitations under the License. import os -import urllib +from urllib.request import urlretrieve +from urllib.error import URLError import hashlib import tarfile from monai.utils import process_bar @@ -36,9 +37,9 @@ def _process_hook(blocknum, blocksize, totalsize): process_bar(blocknum * blocksize, totalsize) try: - urllib.request.urlretrieve(url, filepath, reporthook=_process_hook) + urlretrieve(url, filepath, reporthook=_process_hook) print(f"\ndownloaded file: {filepath}.") - except (urllib.error.URLError, IOError) as e: + except (URLError, IOError) as e: raise e if md5_value is not None: @@ -53,14 +54,14 @@ def _process_hook(blocknum, blocksize, totalsize): ) -def extractall(filepath: str, output_dir: str = None): +def extractall(filepath: str, output_dir: str): """ Extract file to the output directory. Args: filepath: the file path of compressed file. output_dir: target directory to save extracted files. - defaut is None to save in current directory. + """ target_file = os.path.join(output_dir, os.path.basename(filepath).split(".")[0]) if os.path.exists(target_file): @@ -70,17 +71,18 @@ def extractall(filepath: str, output_dir: str = None): datafile.close() -def download_and_extract(url: str, filepath: str, md5_value: str = None, output_dir: str = None): +def download_and_extract(url: str, filepath: str, output_dir: str, md5_value: str = None): """ Download file from URL and extract it to the output directory. Args: url: source URL link to download file. filepath: the file path of compressed file. 
- md5_value: expected MD5 value to validate the downloaded file. - if None, skip MD5 validation. output_dir: target directory to save extracted files. defaut is None to save in current directory. + md5_value: expected MD5 value to validate the downloaded file. + if None, skip MD5 validation. + """ download_url(url=url, filepath=filepath, md5_value=md5_value) extractall(filepath=filepath, output_dir=output_dir) From d14b724ff4f4f8ac93a0d7661a381b37e3080088 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 17 Jun 2020 01:04:10 +0800 Subject: [PATCH 09/11] [DLMED] update check_md5 --- monai/application/utils.py | 39 ++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/monai/application/utils.py b/monai/application/utils.py index 3623d3ea76..bcde4484b4 100644 --- a/monai/application/utils.py +++ b/monai/application/utils.py @@ -17,6 +17,28 @@ from monai.utils import process_bar +def check_md5(filepath: str, md5_value: str = None): + """ + check MD5 signature of specified file. + + Args: + filepath: path of source file to verify MD5. + md5_value: expected MD5 value of the file. + + """ + if md5_value is not None: + md5 = hashlib.md5() + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + md5.update(chunk) + if md5_value != md5.hexdigest(): + return False + else: + print(f"expected MD5 is None, skip MD5 check for file {filepath}.") + + return True + + def download_url(url: str, filepath: str, md5_value: str = None): """ Download file from specified URL link, support process bar and MD5 check. 
def download_url(url: str, filepath: str, md5_value: str = None):
    """
    Download file from specified URL link, support process bar and MD5 check.

    Args:
        url: source URL link to download file.
        filepath: target local path to store the downloaded file.
        md5_value: expected MD5 value to validate the downloaded file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: when an existing or freshly downloaded file fails the MD5 check.

    """
    if os.path.exists(filepath):
        # reuse a previously downloaded file only when it passes the MD5 check;
        # a corrupt leftover would otherwise poison every later run
        if not check_md5(filepath, md5_value):
            raise RuntimeError(f"MD5 check of existing file {filepath} failed, please delete it and try again.")
        print(f"file {filepath} exists, skip downloading.")
        return
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    def _process_hook(blocknum, blocksize, totalsize):
        # adapt urlretrieve's report hook signature to the MONAI progress bar
        process_bar(blocknum * blocksize, totalsize)

    # URLError/IOError from urlretrieve propagate to the caller unchanged;
    # the previous `except ... as e: raise e` added nothing
    urlretrieve(url, filepath, reporthook=_process_hook)
    print(f"\ndownloaded file: {filepath}.")

    if not check_md5(filepath, md5_value):
        # single-line message: the old backslash-continued f-string embedded
        # the source indentation whitespace into the error text
        raise RuntimeError(
            f"MD5 check of downloaded file failed, URL={url}, filepath={filepath}, expected MD5={md5_value}."
        )
extractall +from .utils import download_and_extract class MedNISTDataset(Randomizable, CacheDataset): @@ -25,14 +25,16 @@ class MedNISTDataset(Randomizable, CacheDataset): It's based on `CacheDataset` to accelerate the training process. Args: - root: target dictionary to download and load MedNIST dataset. + root_dir: target directory to download and load MedNIST dataset. section: expected data section, can be: `training`, `validation` or `test`. transform: transforms to execute operations on input data. - download: whether to download the MedNIST from resource link, default is False. + download: whether to download and extract the MedNIST from resource link, default is False. if expected file already exists, skip downloading even set it to True. user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory. extract: whether to extract the MedNIST.tar.gz file under root directory, default is False. seed: random seed to randomly split training, validation and test datasets, defaut is 0. + val_frac: percentage of of validation fraction in the whole dataset, default is 0.1. + test_frac: percentage of of test fraction in the whole dataset, default is 0.1. cache_num: number of items to be cached. Default is `sys.maxsize`. will take the minimum of (cache_num, data_length x cache_rate, data_length). cache_rate: percentage of cached data in total, default is 1.0 (cache all). 
@@ -44,29 +46,33 @@ class MedNISTDataset(Randomizable, CacheDataset): resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1" md5 = "0bc7306e7427e00ad1c5526a6677552d" + compressed_file_name = "MedNIST.tar.gz" + dataset_folder_name = "MedNIST" def __init__( self, - root: str, + root_dir: str, section: str, transform: Callable[..., Any], download: bool = False, - extract: bool = False, seed: int = 0, + val_frac: float = 0.1, + test_frac: float = 0.1, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, ): - if not os.path.isdir(root): - raise ValueError("root must be a directory.") + if not os.path.isdir(root_dir): + raise ValueError("root_dir must be a directory.") self.section = section + self.val_frac = val_frac + self.test_frac = test_frac self.set_random_state(seed=seed) - tarfile_name = os.path.join(root, "MedNIST.tar.gz") - dataset_dir = os.path.join(root, "MedNIST") + tarfile_name = os.path.join(root_dir, self.compressed_file_name) + dataset_dir = os.path.join(root_dir, self.dataset_folder_name) if download: - download_url(self.resource, tarfile_name, self.md5) - if extract: - extractall(tarfile_name, root) + download_and_extract(self.resource, tarfile_name, root_dir, self.md5) + if not os.path.exists(dataset_dir): raise RuntimeError("can not find dataset directory, please use download=True to download it.") data = self._generate_data_list(dataset_dir) @@ -93,20 +99,18 @@ def _generate_data_list(self, dataset_dir): image_class.extend([i] * num_each[i]) num_total = len(image_class) - val_frac = 0.1 - test_frac = 0.1 data = list() for i in range(num_total): self.randomize() if self.section == "training": - if self.rann < val_frac + test_frac: + if self.rann < self.val_frac + self.test_frac: continue elif self.section == "validation": - if self.rann >= val_frac: + if self.rann >= self.val_frac: continue elif self.section == "test": - if self.rann < val_frac or self.rann >= val_frac + test_frac: + if self.rann < 
def download_and_extract(url: str, filepath: str, output_dir: str, md5_value: str = None):
    """
    Fetch a compressed file from a URL and unpack it into the output directory.

    Args:
        url: source URL link to download file.
        filepath: local path where the compressed file is stored.
        output_dir: target directory to save extracted files.
        md5_value: expected MD5 value to validate the downloaded file.
            if None, skip MD5 validation.

    """
    # download first (no-op if a valid copy already exists), then extract;
    # both steps re-check the MD5 so a corrupt archive is never unpacked
    download_url(url, filepath, md5_value)
    extractall(filepath, output_dir, md5_value)
class TestCheckMD5(unittest.TestCase):
    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_shape(self, md5_value, expected_result):
        """check_md5 returns the expected result for a correct, wrong, or absent MD5 value."""
        test_image = np.ones((64, 64, 3))
        # TemporaryDirectory cleans up even when an assertion below fails,
        # unlike the previous mkdtemp()/rmtree() pairing which leaked on failure
        with tempfile.TemporaryDirectory() as tempdir:
            filename = os.path.join(tempdir, "test_file.png")
            Image.fromarray(test_image.astype("uint8")).save(filename)

            result = check_md5(filename, md5_value)
            # assertEqual gives a readable "X != Y" message, unlike assertTrue(a == b)
            self.assertEqual(result, expected_result)
class TestDownloadAndExtract(unittest.TestCase):
    def test_actions(self):
        """End-to-end download/extract flow, plus the failure paths for a wrong MD5."""
        tempdir = tempfile.mkdtemp()
        try:
            url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
            filepath = os.path.join(tempdir, "MedNIST.tar.gz")
            output_dir = tempdir
            md5_value = "0bc7306e7427e00ad1c5526a6677552d"
            # the second call exercises the "already downloaded / already extracted" skip paths
            download_and_extract(url, filepath, output_dir, md5_value)
            download_and_extract(url, filepath, output_dir, md5_value)

            wrong_md5 = "0"
            # assertRaises fails the test when no error is raised, unlike the
            # previous bare try/except which silently passed in that case
            with self.assertRaises(RuntimeError) as cm:
                download_url(url, filepath, wrong_md5)
            self.assertTrue(str(cm.exception).startswith("MD5 check"))

            # remove the extracted folder so extractall must re-validate the archive
            shutil.rmtree(os.path.join(tempdir, "MedNIST"))
            with self.assertRaises(RuntimeError) as cm:
                extractall(filepath, output_dir, wrong_md5)
            self.assertTrue(str(cm.exception).startswith("MD5 check"))
        finally:
            # clean up the tempdir even when an assertion or download fails
            shutil.rmtree(tempdir)
def extractall(filepath: str, output_dir: str, md5_value: str = None):
    """
    Extract file to the output directory.
    Expected file types are: `zip`, `tar.gz` and `tar`.

    Args:
        filepath: the file path of compressed file.
        output_dir: target directory to save extracted files.
        md5_value: expected MD5 value to validate the compressed file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: when the compressed file fails the MD5 check.
        TypeError: when the file extension is not a supported archive type.

    """
    # if the archive was extracted before, its top-level folder already exists
    target_file = os.path.join(output_dir, os.path.basename(filepath).split(".")[0])
    if os.path.exists(target_file):
        print(f"extracted file {target_file} exists, skip extracting.")
        return
    if not check_md5(filepath, md5_value):
        raise RuntimeError(f"MD5 check of compressed file {filepath} failed.")

    # context managers guarantee the archive handle is closed even when
    # extraction raises part-way through (the explicit close() calls did not)
    if filepath.endswith("zip"):
        with zipfile.ZipFile(filepath) as zip_file:
            zip_file.extractall(output_dir)
    elif filepath.endswith("tar") or filepath.endswith("tar.gz"):
        with tarfile.open(filepath) as tar_file:
            tar_file.extractall(output_dir)
    else:
        raise TypeError("unsupported compressed file type.")