Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions monai/application/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import *
from .datasets import *
118 changes: 118 additions & 0 deletions monai/application/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tarfile
import numpy as np
from typing import Any, Callable
from monai.data import CacheDataset
from monai.transforms import Randomizable
from .utils import download_and_extract


class MedNISTDataset(Randomizable, CacheDataset):
    """
    The Dataset to automatically download MedNIST data and generate items for training, validation or test.
    It's based on `CacheDataset` to accelerate the training process.

    Args:
        root_dir: target directory to download and load MedNIST dataset.
        section: expected data section, can be: `training`, `validation` or `test`.
        transform: transforms to execute operations on input data.
        download: whether to download and extract the MedNIST from resource link, default is False.
            if expected file already exists, skip downloading even set it to True.
            user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory.
        seed: random seed to randomly split training, validation and test datasets, default is 0.
        val_frac: fraction of the whole dataset used for validation, default is 0.1.
        test_frac: fraction of the whole dataset used for test, default is 0.1.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads to use.
            If 0 a single thread will be used. Default is 0.

    Raises:
        ValueError: when `root_dir` is not an existing directory,
            or `section` is not one of `training`, `validation`, `test`.
        RuntimeError: when the `MedNIST` folder can not be found under `root_dir`.

    """

    resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
    md5 = "0bc7306e7427e00ad1c5526a6677552d"
    compressed_file_name = "MedNIST.tar.gz"
    dataset_folder_name = "MedNIST"

    def __init__(
        self,
        root_dir: str,
        section: str,
        transform: Callable[..., Any],
        download: bool = False,
        seed: int = 0,
        val_frac: float = 0.1,
        test_frac: float = 0.1,
        cache_num: int = sys.maxsize,
        cache_rate: float = 1.0,
        num_workers: int = 0,
    ):
        if not os.path.isdir(root_dir):
            raise ValueError("root_dir must be a directory.")
        # fail fast: reject an unknown section before any download or directory scan starts
        if section not in ("training", "validation", "test"):
            raise ValueError("section name can only be: training, validation or test.")
        self.section = section
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.set_random_state(seed=seed)
        tarfile_name = os.path.join(root_dir, self.compressed_file_name)
        dataset_dir = os.path.join(root_dir, self.dataset_folder_name)
        if download:
            download_and_extract(self.resource, tarfile_name, root_dir, self.md5)

        if not os.path.exists(dataset_dir):
            raise RuntimeError("can not find dataset directory, please use download=True to download it.")
        data = self._generate_data_list(dataset_dir)
        super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)

    def randomize(self):
        """Draw the next random number from `self.R` and store it in `self.rann`."""
        self.rann = self.R.random()

    def _generate_data_list(self, dataset_dir):
        """
        Build the list of `{"image": path, "label": class_index}` items belonging to `self.section`.

        Every sub-folder of `dataset_dir` is one class; the label is the index of the folder
        in the sorted list of folder names. One random number is drawn per file and compared
        against `val_frac` / `test_frac` to decide which section the file belongs to.

        Args:
            dataset_dir: folder that contains one sub-folder per class.

        Raises:
            ValueError: when `self.section` is not `training`, `validation` or `test`.

        """
        class_names = sorted(x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))
        # random numbers in [val_frac, held_out) select test items, numbers >= held_out select training items
        held_out = self.val_frac + self.test_frac

        data = []
        for label, class_name in enumerate(class_names):
            class_dir = os.path.join(dataset_dir, class_name)
            # os.listdir returns entries in arbitrary order; sort them so that
            # the same seed always produces the same split on every machine
            for file_name in sorted(os.listdir(class_dir)):
                # draw exactly one number per file, whatever the section is
                self.randomize()
                if self.section == "training":
                    selected = self.rann >= held_out
                elif self.section == "validation":
                    selected = self.rann < self.val_frac
                elif self.section == "test":
                    selected = self.val_frac <= self.rann < held_out
                else:
                    raise ValueError("section name can only be: training, validation or test.")
                if selected:
                    data.append({"image": os.path.join(class_dir, file_name), "label": label})
        return data
122 changes: 122 additions & 0 deletions monai/application/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from urllib.request import urlretrieve
from urllib.error import URLError
import hashlib
import tarfile
import zipfile
from monai.utils import process_bar


def check_md5(filepath: str, md5_value: str = None):
    """
    Compare the MD5 digest of a file against an expected value.

    Args:
        filepath: path of the file whose content is hashed.
        md5_value: expected hexadecimal MD5 digest.
            if None, nothing is hashed and the check is reported as passed.

    Returns:
        True when `md5_value` is None or equals the digest of the file, otherwise False.

    """
    if md5_value is None:
        print(f"expected MD5 is None, skip MD5 check for file {filepath}.")
        return True

    digest = hashlib.md5()
    with open(filepath, "rb") as stream:
        # hash the file in 1 MiB pieces so large files are never held in memory at once
        while True:
            block = stream.read(1024 * 1024)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest() == md5_value


def download_url(url: str, filepath: str, md5_value: str = None):
    """
    Download file from specified URL link, support process bar and MD5 check.

    If `filepath` already exists nothing is downloaded; the existing file is only verified.

    Args:
        url: source URL link to download file.
        filepath: target filepath to save the downloaded file.
        md5_value: expected MD5 value to validate the downloaded file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: when the MD5 check of the existing or the downloaded file fails.
        URLError, IOError: re-raised unchanged when the download itself fails.

    """
    if os.path.exists(filepath):
        if not check_md5(filepath, md5_value):
            raise RuntimeError(f"MD5 check of existing file {filepath} failed, please delete it and try again.")
        print(f"file {filepath} exists, skip downloading.")
        return

    # a bare file name has an empty dirname, which os.makedirs refuses to create
    target_dir = os.path.dirname(filepath)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    def _process_hook(blocknum, blocksize, totalsize):
        process_bar(blocknum * blocksize, totalsize)

    try:
        urlretrieve(url, filepath, reporthook=_process_hook)
    except (URLError, IOError):
        # the file did not exist before this call, so anything left behind is a
        # partial download; remove it so the next call does not treat it as valid
        if os.path.isfile(filepath):
            os.remove(filepath)
        raise
    print(f"\ndownloaded file: {filepath}.")

    if not check_md5(filepath, md5_value):
        raise RuntimeError(
            "MD5 check of downloaded file failed, "
            f"URL={url}, filepath={filepath}, expected MD5={md5_value}."
        )


def extractall(filepath: str, output_dir: str, md5_value: str = None):
    """
    Extract file to the output directory.
    Expected file types are: `zip`, `tar.gz` and `tar`.

    Extraction is skipped when `output_dir` already contains an entry named after
    the part of the archive file name that precedes its first dot.

    Args:
        filepath: the file path of compressed file.
        output_dir: target directory to save extracted files.
        md5_value: expected MD5 value to validate the compressed file.
            if None, skip MD5 validation.

    Raises:
        RuntimeError: when the MD5 check of the compressed file fails.
        TypeError: when the file name does not end with `zip`, `tar` or `tar.gz`.

    """
    target_file = os.path.join(output_dir, os.path.basename(filepath).split(".")[0])
    if os.path.exists(target_file):
        print(f"extracted file {target_file} exists, skip extracting.")
        return
    if not check_md5(filepath, md5_value):
        raise RuntimeError(f"MD5 check of compressed file {filepath} failed.")

    # NOTE(review): extracting an archive from an untrusted source can write files
    # outside of output_dir; pass md5_value so only a known archive is unpacked.
    if filepath.endswith("zip"):
        # the context manager closes the archive even when extraction raises
        with zipfile.ZipFile(filepath) as zip_file:
            zip_file.extractall(output_dir)
    elif filepath.endswith(("tar", "tar.gz")):
        with tarfile.open(filepath) as tar_file:
            tar_file.extractall(output_dir)
    else:
        raise TypeError("unsupported compressed file type.")


def download_and_extract(url: str, filepath: str, output_dir: str, md5_value: str = None):
    """
    Fetch a compressed file from a URL, then unpack it into the output directory.

    Args:
        url: source URL link of the compressed file.
        filepath: local path where the compressed file is stored.
        output_dir: directory that receives the extracted files.
        md5_value: expected MD5 value, handed to both the download and the extraction step.
            if None, skip MD5 validation.

    """
    # step 1: make sure the compressed file is available at `filepath`
    download_url(url, filepath, md5_value)
    # step 2: unpack that file into `output_dir`
    extractall(filepath, output_dir, md5_value)
43 changes: 43 additions & 0 deletions tests/test_check_md5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import shutil
import numpy as np
import tempfile
from PIL import Image
from parameterized import parameterized
from monai.application import check_md5

# Each test case is [md5_value passed to check_md5, expected return value].

# digest the test expects to match the PNG file it generates
TEST_CASE_1 = ["f38e9e043c8e902321e827b24ce2e5ec", True]

# digest that is expected NOT to match the generated file
TEST_CASE_2 = ["12c730d4e7427e00ad1c5526a6677535", False]

# no expected digest: the check is expected to report success
TEST_CASE_3 = [None, True]


class TestCheckMD5(unittest.TestCase):
    """Check `check_md5` against a matching digest, a wrong digest and no digest."""

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
    def test_shape(self, md5_value, expected_result):
        """Write a 64x64x3 all-ones PNG and compare `check_md5` with the expected result."""
        test_image = np.ones((64, 64, 3))
        # TemporaryDirectory removes the folder even when the assertion below fails
        with tempfile.TemporaryDirectory() as tempdir:
            filename = os.path.join(tempdir, "test_file.png")
            Image.fromarray(test_image.astype("uint8")).save(filename)

            result = check_md5(filename, md5_value)
            self.assertEqual(result, expected_result)


# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
46 changes: 46 additions & 0 deletions tests/test_download_and_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import shutil
import numpy as np
import tempfile
from monai.application import download_and_extract, download_url, extractall


class TestDownloadAndExtract(unittest.TestCase):
    """Exercise `download_and_extract`, `download_url` and `extractall` on the MedNIST archive."""

    def test_actions(self):
        """Download and extract twice, then check that a wrong MD5 is rejected by both steps."""
        url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
        md5_value = "0bc7306e7427e00ad1c5526a6677552d"
        wrong_md5 = "0"

        # TemporaryDirectory removes the folder even when an assertion fails
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "MedNIST.tar.gz")
            output_dir = tempdir
            download_and_extract(url, filepath, output_dir, md5_value)
            # second call runs against the already downloaded and extracted data
            download_and_extract(url, filepath, output_dir, md5_value)

            # assertRaises fails the test when no error is raised,
            # unlike a bare try/except which would pass silently
            with self.assertRaises(RuntimeError) as context:
                download_url(url, filepath, wrong_md5)
            self.assertTrue(str(context.exception).startswith("MD5 check"))

            # remove the extracted folder so that extractall reaches its MD5 check
            shutil.rmtree(os.path.join(tempdir, "MedNIST"))
            with self.assertRaises(RuntimeError) as context:
                extractall(filepath, output_dir, wrong_md5)
            self.assertTrue(str(context.exception).startswith("MD5 check"))


# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
55 changes: 55 additions & 0 deletions tests/test_mednistdataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import shutil
import tempfile

from monai.application import MedNISTDataset
from tests.utils import NumpyImageTestCase2D
from monai.transforms import LoadPNGd, AddChanneld, ScaleIntensityd, ToTensord, Compose


class TestMedNISTDataset(unittest.TestCase):
    """Check `MedNISTDataset` with and without download, and with a missing dataset folder."""

    def test_values(self):
        """Load the `test` section twice, then expect a RuntimeError once the data folder is gone."""
        transform = Compose(
            [
                LoadPNGd(keys="image"),
                AddChanneld(keys="image"),
                ScaleIntensityd(keys="image"),
                ToTensord(keys=["image", "label"]),
            ]
        )

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 5986)
            first_item = dataset[0]
            self.assertIn("image", first_item)
            self.assertIn("label", first_item)
            self.assertIn("image_meta_dict", first_item)
            self.assertTupleEqual(first_item["image"].shape, (1, 64, 64))

        # TemporaryDirectory removes the folder even when an assertion fails
        with tempfile.TemporaryDirectory() as tempdir:
            data = MedNISTDataset(root_dir=tempdir, transform=transform, section="test", download=True)
            _test_dataset(data)
            # the data is on disk now, so it must load without downloading
            data = MedNISTDataset(root_dir=tempdir, transform=transform, section="test", download=False)
            _test_dataset(data)

            shutil.rmtree(os.path.join(tempdir, "MedNIST"))
            # assertRaises fails the test when no error is raised,
            # unlike a bare try/except which would pass silently
            with self.assertRaises(RuntimeError) as context:
                MedNISTDataset(root_dir=tempdir, transform=transform, section="test", download=False)
            self.assertTrue(str(context.exception).startswith("can not find dataset directory"))


# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()