class MedNISTDataset(Randomizable, CacheDataset):
    """
    The Dataset to automatically download MedNIST data and generate items for training, validation or test.
    It's based on `CacheDataset` to accelerate the training process.

    Every image file is assigned to exactly one section by drawing one random number per file
    from the seeded random state, so the three sections built with the same `seed`, `val_frac`
    and `test_frac` are disjoint and together cover the whole dataset.

    Args:
        root_dir: target directory to download and load MedNIST dataset.
        section: expected data section, can be: `training`, `validation` or `test`.
        transform: transforms to execute operations on input data.
        download: whether to download and extract the MedNIST from resource link, default is False.
            if expected file already exists, skip downloading even set it to True.
            user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory.
        seed: random seed to randomly split training, validation and test datasets, default is 0.
        val_frac: fraction of the whole dataset used for validation, default is 0.1.
        test_frac: fraction of the whole dataset used for test, default is 0.1.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads to use.
            If 0 a single thread will be used. Default is 0.

    Raises:
        ValueError: if `root_dir` is not a directory or `section` is not a supported name.
        RuntimeError: if the dataset folder cannot be found under `root_dir`.

    """

    resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
    md5 = "0bc7306e7427e00ad1c5526a6677552d"
    compressed_file_name = "MedNIST.tar.gz"
    dataset_folder_name = "MedNIST"

    def __init__(
        self,
        root_dir: str,
        section: str,
        transform: Callable[..., Any],
        download: bool = False,
        seed: int = 0,
        val_frac: float = 0.1,
        test_frac: float = 0.1,
        cache_num: int = sys.maxsize,
        cache_rate: float = 1.0,
        num_workers: int = 0,
    ):
        if not os.path.isdir(root_dir):
            raise ValueError("root_dir must be a directory.")
        # reject a bad section name before any download or directory scan is started
        if section not in ("training", "validation", "test"):
            raise ValueError("section name can only be: training, validation or test.")
        self.section = section
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.set_random_state(seed=seed)
        tarfile_name = os.path.join(root_dir, self.compressed_file_name)
        dataset_dir = os.path.join(root_dir, self.dataset_folder_name)
        if download:
            download_and_extract(self.resource, tarfile_name, root_dir, self.md5)

        if not os.path.exists(dataset_dir):
            raise RuntimeError("can not find dataset directory, please use download=True to download it.")
        data = self._generate_data_list(dataset_dir)
        super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)

    def randomize(self):
        """Draw the next random number from the seeded state and store it in `self.rann`."""
        self.rann = self.R.random()

    def _in_section(self, value):
        """Return True if the random draw `value` falls into the range owned by `self.section`."""
        if self.section == "training":
            return value >= self.val_frac + self.test_frac
        if self.section == "validation":
            return value < self.val_frac
        if self.section == "test":
            return self.val_frac <= value < self.val_frac + self.test_frac
        raise ValueError("section name can only be: training, validation or test.")

    def _generate_data_list(self, dataset_dir):
        """
        Build the list of `{"image": path, "label": class_index}` items belonging to `self.section`.

        Each sub-folder of `dataset_dir` is one class; class indices follow the sorted folder names.

        """
        class_names = sorted(
            name for name in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, name))
        )

        data = []
        for label, class_name in enumerate(class_names):
            class_dir = os.path.join(dataset_dir, class_name)
            # os.listdir order is arbitrary; sorting keeps the file -> random draw pairing, and
            # therefore the split membership, identical on every filesystem for a given seed
            for file_name in sorted(os.listdir(class_dir)):
                # exactly one draw per file, whatever the section, so the sections stay disjoint
                self.randomize()
                if self._in_section(self.rann):
                    data.append({"image": os.path.join(class_dir, file_name), "label": label})
        return data
def check_md5(filepath: str, md5_value: str = None):
    """
    check MD5 signature of specified file.

    Args:
        filepath: path of source file to verify MD5.
        md5_value: expected MD5 value of the file.
            if None, the check is skipped (a notice is printed) and the file is accepted.

    Returns:
        True if the MD5 matches or the check was skipped, otherwise False.

    """
    if md5_value is None:
        print(f"expected MD5 is None, skip MD5 check for file {filepath}.")
        return True

    md5 = hashlib.md5()
    with open(filepath, "rb") as f:
        # read in 1 MiB chunks so that large archives are never held in memory as a whole
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            md5.update(chunk)
    return md5_value == md5.hexdigest()


def download_url(url: str, filepath: str, md5_value: str = None):
    """
    Download file from specified URL link, support process bar and MD5 check.

    Args:
        url: source URL link to download file.
        filepath: target filepath to save the downloaded file.
        md5_value: expected MD5 value to validate the downloaded file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: if an existing file, or the freshly downloaded file, fails the MD5 check.
        URLError, IOError: if the download itself fails; no partial file is left behind.

    """
    if os.path.exists(filepath):
        if not check_md5(filepath, md5_value):
            raise RuntimeError(f"MD5 check of existing file {filepath} failed, please delete it and try again.")
        print(f"file {filepath} exists, skip downloading.")
        return

    # a bare file name has an empty dirname, and os.makedirs("") raises FileNotFoundError
    target_dir = os.path.dirname(filepath)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    def _process_hook(blocknum, blocksize, totalsize):
        process_bar(blocknum * blocksize, totalsize)

    try:
        urlretrieve(url, filepath, reporthook=_process_hook)
    except (URLError, IOError):
        # the file did not exist before this call, so whatever is there now is a truncated
        # download; remove it, otherwise the next call would treat it as an existing file
        if os.path.exists(filepath):
            os.remove(filepath)
        raise
    print(f"\ndownloaded file: {filepath}.")

    if not check_md5(filepath, md5_value):
        raise RuntimeError(
            "MD5 check of downloaded file failed, "
            f"URL={url}, filepath={filepath}, expected MD5={md5_value}."
        )


def extractall(filepath: str, output_dir: str, md5_value: str = None):
    """
    Extract file to the output directory.
    Expected file types are: `zip`, `tar.gz` and `tar`.

    Extraction is skipped when `output_dir` already contains an entry named after the archive
    (the archive base name up to its first dot).

    Args:
        filepath: the file path of compressed file.
        output_dir: target directory to save extracted files.
        md5_value: expected MD5 value to validate the compressed file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: if the compressed file fails the MD5 check.
        TypeError: if the file extension is not one of the supported types.

    """
    target_file = os.path.join(output_dir, os.path.basename(filepath).split(".")[0])
    if os.path.exists(target_file):
        print(f"extracted file {target_file} exists, skip extracting.")
        return
    if not check_md5(filepath, md5_value):
        raise RuntimeError(f"MD5 check of compressed file {filepath} failed.")

    if filepath.endswith("zip"):
        # context managers close the archive even when extraction raises
        with zipfile.ZipFile(filepath) as zip_file:
            zip_file.extractall(output_dir)
    elif filepath.endswith(("tar", "tar.gz")):
        with tarfile.open(filepath) as tar_file:
            # NOTE: the archive may come straight from the network. Where the interpreter
            # supports extraction filters, use the "data" filter to refuse members that
            # would be written outside of output_dir.
            if hasattr(tarfile, "data_filter"):
                tar_file.extractall(output_dir, filter="data")
            else:
                tar_file.extractall(output_dir)
    else:
        raise TypeError("unsupported compressed file type.")


def download_and_extract(url: str, filepath: str, output_dir: str = ".", md5_value: str = None):
    """
    Download file from URL and extract it to the output directory.

    Args:
        url: source URL link to download file.
        filepath: the file path of compressed file.
        output_dir: target directory to save extracted files.
            default is "." to save in current directory.
        md5_value: expected MD5 value to validate the downloaded file.
            if None, skip MD5 validation.

    """
    download_url(url=url, filepath=filepath, md5_value=md5_value)
    extractall(filepath=filepath, output_dir=output_dir, md5_value=md5_value)
# Each case is [md5_value passed to check_md5, expected return value].
# NOTE(review): the digest in TEST_CASE_1 is presumably that of the PNG written below;
# it would then depend on the exact bytes the PNG encoder emits - confirm when upgrading PIL.
TEST_CASE_1 = ["f38e9e043c8e902321e827b24ce2e5ec", True]

TEST_CASE_2 = ["12c730d4e7427e00ad1c5526a6677535", False]

TEST_CASE_3 = [None, True]


class TestCheckMD5(unittest.TestCase):
    """Check `check_md5` against a matching digest, a wrong digest and a skipped (None) digest."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_shape(self, md5_value, expected_result):
        test_image = np.ones((64, 64, 3))
        # the context manager removes the directory even when an error is raised inside it
        with tempfile.TemporaryDirectory() as tempdir:
            filename = os.path.join(tempdir, "test_file.png")
            Image.fromarray(test_image.astype("uint8")).save(filename)

            result = check_md5(filename, md5_value)

        # assertEqual reports both values on failure, unlike assertTrue(a == b)
        self.assertEqual(result, expected_result)


if __name__ == "__main__":
    unittest.main()
class TestDownloadAndExtract(unittest.TestCase):
    """Exercise `download_and_extract`, `download_url` and `extractall` on the MedNIST archive."""

    def test_actions(self):
        url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
        md5_value = "0bc7306e7427e00ad1c5526a6677552d"
        wrong_md5 = "0"

        # the context manager removes the directory even when an assertion below fails
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "MedNIST.tar.gz")
            output_dir = tempdir

            download_and_extract(url, filepath, output_dir, md5_value)
            # second call must take the "already downloaded / already extracted" path
            download_and_extract(url, filepath, output_dir, md5_value)

            # assertRaisesRegex fails the test when no exception is raised; the former
            # try/except form passed silently in that case
            with self.assertRaisesRegex(RuntimeError, "^MD5 check"):
                download_url(url, filepath, wrong_md5)

            # remove the extracted folder so that extractall reaches its MD5 check
            shutil.rmtree(os.path.join(tempdir, "MedNIST"))
            with self.assertRaisesRegex(RuntimeError, "^MD5 check"):
                extractall(filepath, output_dir, wrong_md5)


if __name__ == "__main__":
    unittest.main()
class TestMedNISTDataset(unittest.TestCase):
    """Build the MedNIST `test` section with and without downloading and check its items."""

    def test_values(self):
        transform = Compose(
            [
                LoadPNGd(keys="image"),
                AddChanneld(keys="image"),
                ScaleIntensityd(keys="image"),
                ToTensord(keys=["image", "label"]),
            ]
        )

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 5986)
            first_item = dataset[0]
            self.assertIn("image", first_item)
            self.assertIn("label", first_item)
            self.assertIn("image_meta_dict", first_item)
            self.assertTupleEqual(first_item["image"].shape, (1, 64, 64))

        # the context manager removes the directory even when an assertion below fails
        with tempfile.TemporaryDirectory() as tempdir:
            data = MedNISTDataset(root_dir=tempdir, transform=transform, section="test", download=True)
            _test_dataset(data)
            # the data is on disk now, so it must load without downloading
            data = MedNISTDataset(root_dir=tempdir, transform=transform, section="test", download=False)
            _test_dataset(data)

            shutil.rmtree(os.path.join(tempdir, "MedNIST"))
            # assertRaisesRegex fails the test when no exception is raised; the former
            # try/except form passed silently in that case
            with self.assertRaisesRegex(RuntimeError, "^can not find dataset directory"):
                MedNISTDataset(root_dir=tempdir, transform=transform, section="test", download=False)


if __name__ == "__main__":
    unittest.main()