From f6716e2ff295f5b981e3a833b817f203681a21e1 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 27 Apr 2022 18:12:07 +0000 Subject: [PATCH 01/18] Make all transforms optional Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 3c1fc0abed..cff93761ad 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -200,8 +200,8 @@ class PersistentDataset(Dataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], cache_dir: Optional[Union[Path, str]], + transform: Optional[Union[Sequence[Callable], Callable]] = None, hash_func: Callable[..., bytes] = pickle_hashing, pickle_module: str = "pickle", pickle_protocol: int = DEFAULT_PROTOCOL, @@ -374,9 +374,9 @@ class CacheNTransDataset(PersistentDataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], cache_n_trans: int, cache_dir: Optional[Union[Path, str]], + transform: Optional[Union[Sequence[Callable], Callable]] = None, hash_func: Callable[..., bytes] = pickle_hashing, pickle_module: str = "pickle", pickle_protocol: int = DEFAULT_PROTOCOL, @@ -476,7 +476,7 @@ class LMDBDataset(PersistentDataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], + transform: Optional[Union[Sequence[Callable], Callable]] = None, cache_dir: Union[Path, str] = "cache", hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", @@ -669,7 +669,7 @@ class CacheDataset(Dataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], + transform: Optional[Union[Sequence[Callable], Callable]] = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: Optional[int] = 1, @@ -883,8 +883,8 @@ class SmartCacheDataset(Randomizable, CacheDataset): def __init__( self, data: Sequence, - transform: Union[Sequence[Callable], Callable], replace_rate: float, + transform: Optional[Union[Sequence[Callable], Callable]] = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_init_workers: Optional[int] = 1, From f21fe5a323206b0010c4bf6cf377a35d9a36e59e Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 28 Apr 2022 14:48:50 +0000 Subject: [PATCH 02/18] Update wsireader tests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_wsireader.py | 28 +++++++++------------------- tests/test_wsireader_new.py | 16 ++++++++++------ 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 3655100dab..5d092c4ce5 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -84,7 +84,9 @@ TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW -TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel +TEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)] # one color channel +TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color @@ -106,20 +108,6 @@ def save_rgba_tiff(array: np.ndarray, filename: str, mode: str): return filename -def save_gray_tiff(array: np.ndarray, filename: str): - """ - Save numpy array into a TIFF file - - Args: - array: numpy ndarray with any shape - filename: the filename to be used for the tiff file. - """ - img_gray = array - imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack") - - return filename - - @skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!") def setUpModule(): # noqa: N802 hash_type = testing_data_config("images", FILE_KEY, "hash_type") @@ -187,13 +175,15 @@ def test_read_rgba(self, img_expected): self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) - @parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D]) + @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D]) @skipUnless(has_tiff, "Requires tifffile.") def test_read_malformats(self, img_expected): + if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1): + # Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230 + return reader = WSIReader(self.backend) - file_path = save_gray_tiff( - img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") - ) + file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") + imwrite(file_path, img_expected, shape=img_expected.shape) with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)): with reader.read(file_path) as img_obj: reader.get_data(img_obj) diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 63d61dfeb3..2ac4125f97 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -72,7 +72,9 @@ TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW -TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel +TEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)] # no color channel +TEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)] # one color channel +TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color @@ -103,7 +105,7 @@ def save_gray_tiff(array: np.ndarray, filename: str): filename: the filename to be used for the tiff file. """ img_gray = array - imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack") + imwrite(filename, img_gray, shape=img_gray.shape) return filename @@ -180,13 +182,15 @@ def test_read_rgba(self, img_expected): self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) - @parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D]) + @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D]) @skipUnless(has_tiff, "Requires tifffile.") def test_read_malformats(self, img_expected): + if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1): + # Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230 + return reader = WSIReader(self.backend) - file_path = save_gray_tiff( - img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") - ) + file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff") + imwrite(file_path, img_expected, shape=img_expected.shape) with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)): with reader.read(file_path) as img_obj: reader.get_data(img_obj) From 1610bbc65bfcf0f3d832c148b3d69ecd4019fa25 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 28 Apr 2022 14:59:43 +0000 Subject: [PATCH 03/18] Remove optional from PersistentDataset and its derivatives Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index cff93761ad..c8b847bc36 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -200,8 +200,8 @@ class PersistentDataset(Dataset): def __init__( self, data: Sequence, + transform: Union[Sequence[Callable], Callable], cache_dir: Optional[Union[Path, str]], - transform: Optional[Union[Sequence[Callable], Callable]] = None, hash_func: Callable[..., bytes] = pickle_hashing, pickle_module: str = "pickle", pickle_protocol: int = DEFAULT_PROTOCOL, @@ -374,9 +374,9 @@ class CacheNTransDataset(PersistentDataset): def __init__( self, data: Sequence, + transform: Union[Sequence[Callable], Callable], cache_n_trans: int, cache_dir: Optional[Union[Path, str]], - transform: Optional[Union[Sequence[Callable], Callable]] = None, hash_func: Callable[..., bytes] = pickle_hashing, pickle_module: str = "pickle", pickle_protocol: int = DEFAULT_PROTOCOL, @@ -476,7 +476,7 @@ class LMDBDataset(PersistentDataset): def __init__( self, data: Sequence, - transform: Optional[Union[Sequence[Callable], Callable]] = None, + transform: Union[Sequence[Callable], Callable], cache_dir: Union[Path, str] = "cache", hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", From 3d9516d38a417197b5b282ec86dd6887baeabfbe Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 28 Apr 2022 15:26:43 +0000 Subject: [PATCH 04/18] Add unittests for cache without transform Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_cachedataset.py | 6 ++++++ tests/test_smartcachedataset.py | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 4b77d4a55a..4fa1b5ea69 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -55,6 +55,12 @@ def test_shape(self, transform, expected_shape): data4 = dataset[-1] self.assertEqual(len(data3), 1) + if transform is None: + # Check without providing transfrom + dataset2 = CacheDataset(data=test_data, cache_rate=0.5, as_contiguous=True) + for k in ["image", "label", "extra"]: + self.assertEqual(dataset[0][k], dataset2[0][k]) + if transform is None: self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index e7d51be63a..6eca6113f0 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -56,6 +56,17 @@ def test_shape(self, replace_rate, num_replace_workers, transform): num_init_workers=4, num_replace_workers=num_replace_workers, ) + if transform is None: + # Check without providing transfrom + dataset2 = SmartCacheDataset( + data=test_data, + replace_rate=replace_rate, + cache_num=16, + num_init_workers=4, + num_replace_workers=num_replace_workers, + ) + for k in ["image", "label", "extra"]: + self.assertEqual(dataset[0][k], dataset2[0][k]) self.assertEqual(len(dataset._cache), dataset.cache_num) for i in range(dataset.cache_num): From a9f24c5d04e0a2ffcb9be563754e312f23df30b8 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 28 Apr 2022 17:14:27 +0000 Subject: [PATCH 05/18] Add default replace_rate Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index c8b847bc36..42ac3b8f99 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -669,7 +669,7 @@ class CacheDataset(Dataset): def __init__( self, data: Sequence, - transform: Optional[Union[Sequence[Callable], Callable]] = None, + transform: Optional[Union[Sequence[Callable], Callable]], cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: Optional[int] = 1, @@ -883,8 +883,8 @@ class SmartCacheDataset(Randomizable, CacheDataset): def __init__( self, data: Sequence, - replace_rate: float, transform: Optional[Union[Sequence[Callable], Callable]] = None, + replace_rate: float = 0.5, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_init_workers: Optional[int] = 1, From 0d6450b0fb4bc6920a241b8f743a44209db78aa1 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 28 Apr 2022 19:54:50 +0000 Subject: [PATCH 06/18] Add default value Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 42ac3b8f99..54e8e9fdb1 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -669,7 +669,7 @@ class CacheDataset(Dataset): def __init__( self, data: Sequence, - transform: Optional[Union[Sequence[Callable], Callable]], + transform: Optional[Union[Sequence[Callable], Callable]] = None, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: Optional[int] = 1, From 20c4882ae69c490cd841e3121a54cd1fb1e8e774 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 29 Apr 2022 12:10:56 +0000 Subject: [PATCH 07/18] Set default replace_rate to 0.1 Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 54e8e9fdb1..29342742cb 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -856,7 +856,7 @@ class SmartCacheDataset(Randomizable, CacheDataset): Args: data: input data to load and transform to generate dataset for model. transform: transforms to execute operations on input data. - replace_rate: percentage of the cached items to be replaced in every epoch. + replace_rate: percentage of the cached items to be replaced in every epoch (default to 0.1). cache_num: number of items to be cached. Default is `sys.maxsize`. will take the minimum of (cache_num, data_length x cache_rate, data_length). cache_rate: percentage of cached data in total, default is 1.0 (cache all). @@ -884,7 +884,7 @@ def __init__( self, data: Sequence, transform: Optional[Union[Sequence[Callable], Callable]] = None, - replace_rate: float = 0.5, + replace_rate: float = 0.1, cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_init_workers: Optional[int] = 1, From f586386bc01634d901bdd4ab74db32953ffb77be Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 2 May 2022 14:37:19 +0000 Subject: [PATCH 08/18] Update metadata to include path Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_reader.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 02032a0ae6..0ef14d18dc 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -63,7 +63,7 @@ def __init__(self, level: int, **kwargs): @abstractmethod def get_size(self, wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -103,11 +103,14 @@ def get_patch( raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + def get_metadata( + self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int + ) -> Dict: """ Returns metadata of the extracted patch from the whole slide image. Args: + wsi: the whole slide image object, from which the patch is loaded patch: extracted patch from whole slide image location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). @@ -194,7 +197,7 @@ def get_data( patch_list.append(patch) # Set patch-related metadata - each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level) + each_meta = self.get_metadata(wsi=wsi, patch=patch, location=location, size=size, level=level) metadata.update(each_meta) return _stack_images(patch_list, metadata), metadata @@ -247,7 +250,7 @@ def get_level_count(self, wsi) -> int: def get_size(self, wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -256,11 +259,14 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]: """ return self.reader.get_size(wsi, level) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + def get_metadata( + self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int + ) -> Dict: """ Returns metadata of the extracted patch from the whole slide image. Args: + wsi: the whole slide image object, from which the patch is loaded patch: extracted patch from whole slide image location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). @@ -268,7 +274,7 @@ def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple level: the level number. Defaults to 0 """ - return self.reader.get_metadata(patch=patch, size=size, location=location, level=level) + return self.reader.get_metadata(wsi=wsi, patch=patch, size=size, location=location, level=level) def get_patch( self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str @@ -335,7 +341,7 @@ def get_level_count(wsi) -> int: @staticmethod def get_size(wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -344,11 +350,14 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + def get_metadata( + self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int + ) -> Dict: """ Returns metadata of the extracted patch from the whole slide image. Args: + wsi: the whole slide image object, from which the patch is loaded patch: extracted patch from whole slide image location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). @@ -358,6 +367,7 @@ def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple """ metadata: Dict = { "backend": "cucim", + "wsi_path": wsi.path, "spatial_shape": np.asarray(patch.shape[1:]), "original_channel_dim": 0, "location": location, @@ -458,7 +468,7 @@ def get_level_count(wsi) -> int: @staticmethod def get_size(wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -467,11 +477,14 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0]) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + def get_metadata( + self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int + ) -> Dict: """ Returns metadata of the extracted patch from the whole slide image. Args: + wsi: the whole slide image object, from which the patch is loaded patch: extracted patch from whole slide image location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). From 9dc6ca4a8bf3484db5b29d04e7a3423513913c8e Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 2 May 2022 14:38:08 +0000 Subject: [PATCH 09/18] Adds SmartCachePatchWSIDataset Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_datasets.py | 89 ++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 9 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index a895e8aa45..f01bb17244 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -10,16 +10,18 @@ # limitations under the License. import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union +import sys +from itertools import product +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np -from monai.data import Dataset +from monai.data import Dataset, SmartCacheDataset from monai.data.wsi_reader import BaseWSIReader, WSIReader from monai.transforms import apply_transform from monai.utils import ensure_tuple_rep -__all__ = ["PatchWSIDataset"] +__all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset"] class PatchWSIDataset(Dataset): @@ -32,10 +34,12 @@ class PatchWSIDataset(Dataset): size: the size of patch to be extracted from the whole slide image. level: the level at which the patches to be extracted (default to 0). transform: transforms to be executed on input data. - reader: the module to be used for loading whole slide imaging, - - if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM. - - if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader. - - if `reader` is an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader. + reader: the module to be used for loading whole slide imaging. If `reader` is + + - a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM. + - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader. + - an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader. + kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class Note: @@ -45,14 +49,14 @@ class PatchWSIDataset(Dataset): [ {"image": "path/to/image1.tiff", "location": [200, 500], "label": 0}, - {"image": "path/to/image2.tiff", "location": [100, 700], "label": 1} + {"image": "path/to/image2.tiff", "location": [100, 700], "size": [20, 20], "level": 2, "label": 1} ] """ def __init__( self, - data: List, + data: Sequence, size: Optional[Union[int, Tuple[int, int]]] = None, level: Optional[int] = None, transform: Optional[Callable] = None, @@ -133,3 +137,70 @@ def _transform(self, index: int): # Create put all patch information together and apply transforms patch = {"image": image, "label": label, "metadata": metadata} return apply_transform(self.transform, patch) if self.transform else patch + + +class SmartCachePatchWSIDataset(SmartCacheDataset): + """Add SmartCache functionality to `PatchWSIDataset`. + + Args: + data: the list of input samples including image, location, and label (see the note below for more details). + size: the size of patch to be extracted from the whole slide image. + level: the level at which the patches to be extracted (default to 0). + transform: transforms to be executed on input data. + reader_name: the name of library to be used for loading whole slide imaging, as the backend of `monai.data.WSIReader` + Defaults to CuCIM. + replace_rate: percentage of the cached items to be replaced in every epoch. + cache_num: number of items to be cached. Default is `sys.maxsize`. + will take the minimum of (cache_num, data_length x cache_rate, data_length). + cache_rate: percentage of cached data in total, default is 1.0 (cache all). + will take the minimum of (cache_num, data_length x cache_rate, data_length). + num_init_workers: the number of worker threads to initialize the cache for first epoch. + If num_init_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is specified, 1 will be used instead. + num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. + If num_replace_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is specified, 1 will be used instead. + progress: whether to display a progress bar when caching for the first epoch. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. + kwargs: additional parameters for ``WSIReader`` + + """ + + def __init__( + self, + data: Sequence, + size: Optional[Union[int, Tuple[int, int]]] = None, + level: Optional[int] = None, + transform: Optional[Union[Sequence[Callable], Callable]] = None, + reader="cuCIM", + replace_rate: float = 0.1, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_init_workers: Optional[int] = 1, + num_replace_workers: Optional[int] = 1, + progress: bool = True, + shuffle: bool = False, + seed: int = 0, + copy_cache: bool = True, + as_contiguous: bool = True, + ): + patch_wsi_dataset = PatchWSIDataset(data=data, size=size, level=level, reader=reader) + super().__init__( + data=patch_wsi_dataset, # type: ignore + transform=transform, + replace_rate=replace_rate, + cache_num=cache_num, + cache_rate=cache_rate, + num_init_workers=num_init_workers, + num_replace_workers=num_replace_workers, + progress=progress, + shuffle=shuffle, + seed=seed, + copy_cache=copy_cache, + as_contiguous=as_contiguous, + ) From 6ad1b305476d7d071ec63dffa8e07fbed2679281 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 2 May 2022 15:26:57 +0000 Subject: [PATCH 10/18] Add unittests for SmartCachePatchWSIDataset Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../test_smartcache_patch_wsi_dataset_new.py | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 tests/test_smartcache_patch_wsi_dataset_new.py diff --git a/tests/test_smartcache_patch_wsi_dataset_new.py b/tests/test_smartcache_patch_wsi_dataset_new.py new file mode 100644 index 0000000000..351677b1a3 --- /dev/null +++ b/tests/test_smartcache_patch_wsi_dataset_new.py @@ -0,0 +1,171 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest import skipUnless + +import numpy as np +from numpy.testing import assert_array_equal +from parameterized import parameterized + +from monai.data.wsi_datasets import SmartCachePatchWSIDataset +from monai.utils import optional_import +from tests.utils import download_url_or_skip_test, testing_data_config + +_cucim, has_cim = optional_import("cucim") +has_cim = has_cim and hasattr(_cucim, "CuImage") + +FILE_KEY = "wsi_img" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) + +TEST_CASE_0 = [ + { + "data": [ + {"image": FILE_PATH, "location": [0, 0], "label": [0]}, + {"image": FILE_PATH, "location": [0, 0], "label": [1]}, + {"image": FILE_PATH, "location": [0, 0], "label": [2]}, + {"image": FILE_PATH, "location": [0, 0], "label": [3]}, + ], + "size": (1, 1), + "transform": lambda x: x, + "reader": "cuCIM", + "replace_rate": 0.5, + "cache_num": 2, + "num_init_workers": 1, + "num_replace_workers": 1, + "copy_cache": False, + }, + [ + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([0])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([2])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([2])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([3])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([3])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([0])}, + ], +] + +TEST_CASE_1 = [ + { + "data": [ + {"image": FILE_PATH, "location": [0, 0], "label": [[0, 0]]}, + {"image": FILE_PATH, "location": [0, 0], "label": [[1, 1]]}, + {"image": FILE_PATH, "location": [0, 0], "label": [[2, 2]]}, + ], + "size": (1, 1), + "transform": lambda x: x, + "reader": "cuCIM", + "replace_rate": 0.5, + "cache_num": 2, + "num_init_workers": 1, + "num_replace_workers": 1, + }, + [ + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[0, 0]])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[1, 1]])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[1, 1]])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[2, 2]])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[2, 2]])}, + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[0, 0]])}, + ], +] + +TEST_CASE_2 = [ + { + "data": [ + {"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 0]}, + {"image": FILE_PATH, "location": [10004, 20004], "label": [1, 1, 1, 1]}, + {"image": FILE_PATH, "location": [10004, 20004], "label": [2, 2, 2, 2]}, + ], + "size": (2, 2), + "transform": lambda x: x, + "reader": "cuCIM", + "replace_rate": 0.5, + "cache_num": 1, + "num_init_workers": 1, + "num_replace_workers": 1, + }, + [ + { + "image": np.array( + [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 + ), + "label": np.array([0, 0, 0, 0]), + }, + { + "image": np.array( + [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 + ), + "label": np.array([1, 1, 1, 1]), + }, + { + "image": np.array( + [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 + ), + "label": np.array([2, 2, 2, 2]), + }, + { + "image": np.array( + [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 + ), + "label": np.array([0, 0, 0, 0]), + }, + { + "image": np.array( + [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 + ), + "label": np.array([1, 1, 1, 1]), + }, + { + "image": np.array( + [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 + ), + "label": np.array([2, 2, 2, 2]), + }, + ], +] + + +class TestSmartCachePatchWSIDataset(unittest.TestCase): + def setUp(self): + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) + + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @skipUnless(has_cim, "Requires CuCIM") + def test_read_patches(self, input_parameters, expected): + dataset = SmartCachePatchWSIDataset(**input_parameters) + self.assertEqual(len(dataset), input_parameters["cache_num"]) + total_num_samples = len(input_parameters["data"]) + num_epochs = int( + np.ceil(total_num_samples / (input_parameters["cache_num"] * input_parameters["replace_rate"])) + ) + + dataset.start() + cache_num = input_parameters["cache_num"] + for i in range(num_epochs): + for j in range(len(dataset)): + self.assertTupleEqual(dataset[j]["label"].shape, expected[i * cache_num + j]["label"].shape) + self.assertTupleEqual(dataset[j]["image"].shape, expected[i * cache_num + j]["image"].shape) + self.assertIsNone(assert_array_equal(dataset[j]["label"], expected[i * cache_num + j]["label"])) + self.assertIsNone(assert_array_equal(dataset[j]["image"], expected[i * cache_num + j]["image"])) + dataset.update_cache() + dataset.shutdown() + + +if __name__ == "__main__": + unittest.main() From c54875038ce86c47e7f7bddb4db22cda12e590a9 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 2 May 2022 15:28:08 +0000 Subject: [PATCH 11/18] Update references Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/__init__.py | 2 +- monai/data/wsi_datasets.py | 3 ++- monai/data/wsi_reader.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index d9af568508..f4842ce8a0 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -87,5 +87,5 @@ worker_init_fn, zoom_affine, ) -from .wsi_datasets import PatchWSIDataset +from .wsi_datasets import PatchWSIDataset, SmartCachePatchWSIDataset from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, WSIReader diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index f01bb17244..9e19bc04aa 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -188,8 +188,9 @@ def __init__( seed: int = 0, copy_cache: bool = True, as_contiguous: bool = True, + **kwargs, ): - patch_wsi_dataset = PatchWSIDataset(data=data, size=size, level=level, reader=reader) + patch_wsi_dataset = PatchWSIDataset(data=data, size=size, level=level, reader=reader, **kwargs) super().__init__( data=patch_wsi_dataset, # type: ignore transform=transform, diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 0ef14d18dc..cb7ca5bc99 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -197,7 +197,7 @@ def get_data( patch_list.append(patch) # Set patch-related metadata - each_meta = self.get_metadata(wsi=wsi, patch=patch, location=location, size=size, level=level) + each_meta = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level) metadata.update(each_meta) return _stack_images(patch_list, metadata), metadata From c874f8e9845d889a2718686514f98f3139e2eca8 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 2 May 2022 15:34:11 +0000 Subject: [PATCH 12/18] Update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/data.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/data.rst b/docs/source/data.rst index 02e8031117..60c681e065 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -311,3 +311,8 @@ PatchWSIDataset ~~~~~~~~~~~~~~~ .. autoclass:: monai.data.PatchWSIDataset :members: + +SmartCachePatchWSIDataset +~~~~~~~~~~~~~~~ +.. autoclass:: monai.data.SmartCachePatchWSIDataset + :members: From 21e08c4b3307205f8773749de719dd6351248459 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 2 May 2022 16:49:57 +0000 Subject: [PATCH 13/18] Remove smart cache Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/data.rst | 5 - monai/data/__init__.py | 2 +- monai/data/wsi_datasets.py | 70 +------ .../test_smartcache_patch_wsi_dataset_new.py | 171 ------------------ 4 files changed, 2 insertions(+), 246 deletions(-) delete mode 100644 tests/test_smartcache_patch_wsi_dataset_new.py diff --git a/docs/source/data.rst b/docs/source/data.rst index 60c681e065..02e8031117 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -311,8 +311,3 @@ PatchWSIDataset ~~~~~~~~~~~~~~~ .. autoclass:: monai.data.PatchWSIDataset :members: - -SmartCachePatchWSIDataset -~~~~~~~~~~~~~~~ -.. autoclass:: monai.data.SmartCachePatchWSIDataset - :members: diff --git a/monai/data/__init__.py b/monai/data/__init__.py index f4842ce8a0..d9af568508 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -87,5 +87,5 @@ worker_init_fn, zoom_affine, ) -from .wsi_datasets import PatchWSIDataset, SmartCachePatchWSIDataset +from .wsi_datasets import PatchWSIDataset from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, WSIReader diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 9e19bc04aa..ace2470a6b 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -21,7 +21,7 @@ from monai.transforms import apply_transform from monai.utils import ensure_tuple_rep -__all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset"] +__all__ = ["PatchWSIDataset"] class PatchWSIDataset(Dataset): @@ -137,71 +137,3 @@ def _transform(self, index: int): # Create put all patch information together and apply transforms patch = {"image": image, "label": label, "metadata": metadata} return apply_transform(self.transform, patch) if self.transform else patch - - -class SmartCachePatchWSIDataset(SmartCacheDataset): - """Add SmartCache functionality to `PatchWSIDataset`. - - Args: - data: the list of input samples including image, location, and label (see the note below for more details). - size: the size of patch to be extracted from the whole slide image. - level: the level at which the patches to be extracted (default to 0). - transform: transforms to be executed on input data. - reader_name: the name of library to be used for loading whole slide imaging, as the backend of `monai.data.WSIReader` - Defaults to CuCIM. - replace_rate: percentage of the cached items to be replaced in every epoch. - cache_num: number of items to be cached. Default is `sys.maxsize`. - will take the minimum of (cache_num, data_length x cache_rate, data_length). - cache_rate: percentage of cached data in total, default is 1.0 (cache all). - will take the minimum of (cache_num, data_length x cache_rate, data_length). - num_init_workers: the number of worker threads to initialize the cache for first epoch. - If num_init_workers is None then the number returned by os.cpu_count() is used. - If a value less than 1 is specified, 1 will be used instead. - num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. - If num_replace_workers is None then the number returned by os.cpu_count() is used. - If a value less than 1 is specified, 1 will be used instead. - progress: whether to display a progress bar when caching for the first epoch. - copy_cache: whether to `deepcopy` the cache content before applying the random transforms, - default to `True`. if the random transforms don't modify the cache content - or every cache item is only used once in a `multi-processing` environment, - may set `copy=False` for better performance. - as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. - it may help improve the performance of following logic. - kwargs: additional parameters for ``WSIReader`` - - """ - - def __init__( - self, - data: Sequence, - size: Optional[Union[int, Tuple[int, int]]] = None, - level: Optional[int] = None, - transform: Optional[Union[Sequence[Callable], Callable]] = None, - reader="cuCIM", - replace_rate: float = 0.1, - cache_num: int = sys.maxsize, - cache_rate: float = 1.0, - num_init_workers: Optional[int] = 1, - num_replace_workers: Optional[int] = 1, - progress: bool = True, - shuffle: bool = False, - seed: int = 0, - copy_cache: bool = True, - as_contiguous: bool = True, - **kwargs, - ): - patch_wsi_dataset = PatchWSIDataset(data=data, size=size, level=level, reader=reader, **kwargs) - super().__init__( - data=patch_wsi_dataset, # type: ignore - transform=transform, - replace_rate=replace_rate, - cache_num=cache_num, - cache_rate=cache_rate, - num_init_workers=num_init_workers, - num_replace_workers=num_replace_workers, - progress=progress, - shuffle=shuffle, - seed=seed, - copy_cache=copy_cache, - as_contiguous=as_contiguous, - ) diff --git a/tests/test_smartcache_patch_wsi_dataset_new.py b/tests/test_smartcache_patch_wsi_dataset_new.py deleted file mode 100644 index 351677b1a3..0000000000 --- a/tests/test_smartcache_patch_wsi_dataset_new.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest -from unittest import skipUnless - -import numpy as np -from numpy.testing import assert_array_equal -from parameterized import parameterized - -from monai.data.wsi_datasets import SmartCachePatchWSIDataset -from monai.utils import optional_import -from tests.utils import download_url_or_skip_test, testing_data_config - -_cucim, has_cim = optional_import("cucim") -has_cim = has_cim and hasattr(_cucim, "CuImage") - -FILE_KEY = "wsi_img" -FILE_URL = testing_data_config("images", FILE_KEY, "url") -base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" -FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) - -TEST_CASE_0 = [ - { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [0]}, - {"image": FILE_PATH, "location": [0, 0], "label": [1]}, - {"image": FILE_PATH, "location": [0, 0], "label": [2]}, - {"image": FILE_PATH, "location": [0, 0], "label": [3]}, - ], - "size": (1, 1), - "transform": lambda x: x, - "reader": "cuCIM", - "replace_rate": 0.5, - "cache_num": 2, - "num_init_workers": 1, - "num_replace_workers": 1, - "copy_cache": False, - }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([0])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([2])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([2])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([3])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([3])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([0])}, - ], -] - -TEST_CASE_1 = [ - { - "data": [ - {"image": FILE_PATH, "location": [0, 0], "label": [[0, 0]]}, - {"image": FILE_PATH, "location": [0, 0], "label": [[1, 1]]}, - {"image": FILE_PATH, "location": [0, 0], "label": [[2, 2]]}, - ], - "size": (1, 1), - "transform": lambda x: x, - "reader": "cuCIM", - "replace_rate": 0.5, - "cache_num": 2, - "num_init_workers": 1, - "num_replace_workers": 1, - }, - [ - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[0, 0]])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[1, 1]])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[1, 1]])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[2, 2]])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[2, 2]])}, - {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[0, 0]])}, - ], -] - -TEST_CASE_2 = [ - { - "data": [ - {"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 0]}, - {"image": FILE_PATH, "location": [10004, 20004], "label": [1, 1, 1, 1]}, - {"image": FILE_PATH, "location": [10004, 20004], "label": [2, 2, 2, 2]}, - ], - "size": (2, 2), - "transform": lambda x: x, - "reader": "cuCIM", - "replace_rate": 0.5, - "cache_num": 1, - "num_init_workers": 1, - "num_replace_workers": 1, - }, - [ - { - "image": np.array( - [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 - ), - "label": np.array([0, 0, 0, 0]), - }, - { - "image": np.array( - [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 - ), - "label": np.array([1, 1, 1, 1]), - }, - { - "image": np.array( - [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 - ), - "label": np.array([2, 2, 2, 2]), - }, - { - "image": np.array( - [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 - ), - "label": np.array([0, 0, 0, 0]), - }, - { - "image": np.array( - [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 - ), - "label": np.array([1, 1, 1, 1]), - }, - { - "image": np.array( - [[[247, 246], [246, 246]], [[245, 246], [246, 246]], [[246, 244], [244, 244]]], dtype=np.uint8 - ), - "label": np.array([2, 2, 2, 2]), - }, - ], -] - - -class TestSmartCachePatchWSIDataset(unittest.TestCase): - def setUp(self): - hash_type = testing_data_config("images", FILE_KEY, "hash_type") - hash_val = testing_data_config("images", FILE_KEY, "hash_val") - download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) - - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - @skipUnless(has_cim, "Requires CuCIM") - def test_read_patches(self, input_parameters, expected): - dataset = SmartCachePatchWSIDataset(**input_parameters) - self.assertEqual(len(dataset), input_parameters["cache_num"]) - total_num_samples = len(input_parameters["data"]) - num_epochs = int( - np.ceil(total_num_samples / (input_parameters["cache_num"] * input_parameters["replace_rate"])) - ) - - dataset.start() - cache_num = input_parameters["cache_num"] - for i in range(num_epochs): - for j in range(len(dataset)): - self.assertTupleEqual(dataset[j]["label"].shape, expected[i * cache_num + j]["label"].shape) - self.assertTupleEqual(dataset[j]["image"].shape, expected[i * cache_num + j]["image"].shape) - self.assertIsNone(assert_array_equal(dataset[j]["label"], expected[i * cache_num + j]["label"])) - self.assertIsNone(assert_array_equal(dataset[j]["image"], expected[i * cache_num + j]["image"])) - dataset.update_cache() - dataset.shutdown() - - -if __name__ == "__main__": - unittest.main() From a9de037c0e63c0ea78d43c91bc454a4960fab4cb Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 2 May 2022 16:55:06 +0000 Subject: [PATCH 14/18] Remove unused imports Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_datasets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index ace2470a6b..750b3fda20 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -10,13 +10,11 @@ # limitations under the License. import inspect -import sys -from itertools import product from typing import Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np -from monai.data import Dataset, SmartCacheDataset +from monai.data import Dataset from monai.data.wsi_reader import BaseWSIReader, WSIReader from monai.transforms import apply_transform from monai.utils import ensure_tuple_rep From 0e8879214ee9911ad3f5761dfc65290428e33e54 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 3 May 2022 13:58:15 +0000 Subject: [PATCH 15/18] Add path metadata for OpenSlide Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index cb7ca5bc99..03d81479e3 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -494,6 +494,7 @@ def get_metadata( """ metadata: Dict = { "backend": "openslide", + "wsi_path": wsi._filename, "spatial_shape": np.asarray(patch.shape[1:]), "original_channel_dim": 0, "location": location, From ffd439e9755a72958e4916d079cbd4eda43c97ea Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 3 May 2022 14:49:59 +0000 Subject: [PATCH 16/18] Update metadata to be unified across different backends Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_reader.py | 94 +++++++++++----------------------------- 1 file changed, 26 insertions(+), 68 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 03d81479e3..d45801f64f 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -10,6 +10,7 @@ # limitations under the License. from abc import abstractmethod +from os.path import abspath from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -53,6 +54,7 @@ class BaseWSIReader(ImageReader): """ supported_suffixes: List[str] = [] + backend = "" def __init__(self, level: int, **kwargs): super().__init__() @@ -83,6 +85,11 @@ def get_level_count(self, wsi) -> int: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod def get_patch( self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str @@ -102,7 +109,6 @@ def get_patch( """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - @abstractmethod def get_metadata( self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int ) -> Dict: @@ -118,7 +124,14 @@ def get_metadata( level: the level number. Defaults to 0 """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + metadata: Dict = { + "backend": self.backend, + "original_channel_dim": 0, + "spatial_shape": np.asarray(patch.shape[1:]), + "wsi": {"path": self.get_file_path(wsi)}, + "patch": {"location": location, "size": size, "level": level}, + } + return metadata def get_data( self, @@ -259,22 +272,9 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]: """ return self.reader.get_size(wsi, level) - def get_metadata( - self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int - ) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - wsi: the whole slide image object, from which the patch is loaded - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - return self.reader.get_metadata(wsi=wsi, patch=patch, size=size, location=location, level=level) + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return self.reader.get_file_path(wsi) def get_patch( self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str @@ -323,6 +323,7 @@ class CuCIMWSIReader(BaseWSIReader): """ supported_suffixes = ["tif", "tiff", "svs"] + backend = "cucim" def __init__(self, level: int = 0, **kwargs): super().__init__(level, **kwargs) @@ -350,31 +351,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) - def get_metadata( - self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int - ) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - wsi: the whole slide image object, from which the patch is loaded - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - metadata: Dict = { - "backend": "cucim", - "wsi_path": wsi.path, - "spatial_shape": np.asarray(patch.shape[1:]), - "original_channel_dim": 0, - "location": location, - "size": size, - "level": level, - } - return metadata + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return str(abspath(wsi.path)) def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ @@ -450,6 +429,7 @@ class OpenSlideWSIReader(BaseWSIReader): """ supported_suffixes = ["tif", "tiff", "svs"] + backend = "openslide" def __init__(self, level: int = 0, **kwargs): super().__init__(level, **kwargs) @@ -477,31 +457,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0]) - def get_metadata( - self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int - ) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - wsi: the whole slide image object, from which the patch is loaded - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - metadata: Dict = { - "backend": "openslide", - "wsi_path": wsi._filename, - "spatial_shape": np.asarray(patch.shape[1:]), - "original_channel_dim": 0, - "location": location, - "size": size, - "level": level, - } - return metadata + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return str(abspath(wsi._filename)) def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ From a7d2dcbdf07ac810dc625791d8e99b4cd6d9b5dc Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 6 May 2022 15:00:08 +0000 Subject: [PATCH 17/18] Update wsi metadata for multi wsi objects Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/data/wsi_reader.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index d45801f64f..8dee1f453e 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -211,7 +211,25 @@ def get_data( # Set patch-related metadata each_meta = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level) - metadata.update(each_meta) + + if len(wsi) == 1: + metadata = each_meta + else: + if not metadata: + metadata = { + "backend": each_meta["backend"], + "original_channel_dim": each_meta["original_channel_dim"], + "spatial_shape": each_meta["spatial_shape"], + "wsi": [each_meta["wsi"]], + "patch": [each_meta["patch"]], + } + else: + if metadata["original_channel_dim"] != each_meta["original_channel_dim"]: + raise ValueError("original_channel_dim is not consistent across wsi objects.") + if any(metadata["spatial_shape"] != each_meta["spatial_shape"]): + raise ValueError("spatial_shape is not consistent across wsi objects.") + metadata["wsi"].append(each_meta["wsi"]) + metadata["patch"].append(each_meta["patch"]) return _stack_images(patch_list, metadata), metadata From 960c11726d9600296f12f06a9ffc672e7bc343d8 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 6 May 2022 15:00:25 +0000 Subject: [PATCH 18/18] Add unittests for wsi metadata Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_wsireader_new.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 2ac4125f97..4faec53978 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -125,8 +125,13 @@ class Tests(unittest.TestCase): def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReader(self.backend, level=level) with reader.read(file_path) as img_obj: - img = reader.get_data(img_obj)[0] + img, meta = reader.get_data(img_obj) self.assertTupleEqual(img.shape, expected_shape) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path))) + self.assertEqual(meta["patch"]["level"], level) + self.assertTupleEqual(meta["patch"]["size"], expected_shape[1:]) + self.assertTupleEqual(meta["patch"]["location"], (0, 0)) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_read_region(self, file_path, patch_info, expected_img): @@ -138,29 +143,39 @@ def test_read_region(self, file_path, patch_info, expected_img): reader.get_data(img_obj, **patch_info)[0] else: # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] + img, meta = reader.get_data(img_obj, **patch_info) img2 = reader.get_data(img_obj, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) self.assertIsNone(assert_array_equal(img, img2)) self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path))) + self.assertEqual(meta["patch"]["level"], patch_info["level"]) + self.assertTupleEqual(meta["patch"]["size"], expected_img.shape[1:]) + self.assertTupleEqual(meta["patch"]["location"], patch_info["location"]) @parameterized.expand([TEST_CASE_3]) - def test_read_region_multi_wsi(self, file_path, patch_info, expected_img): + def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img): kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} reader = WSIReader(self.backend, **kwargs) - img_obj = reader.read(file_path, **kwargs) + img_obj_list = reader.read(file_path_list, **kwargs) if self.backend == "tifffile": with self.assertRaises(ValueError): - reader.get_data(img_obj, **patch_info)[0] + reader.get_data(img_obj_list, **patch_info)[0] else: # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] - img2 = reader.get_data(img_obj, **patch_info)[0] + img, meta = reader.get_data(img_obj_list, **patch_info) + img2 = reader.get_data(img_obj_list, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) self.assertIsNone(assert_array_equal(img, img2)) self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"][0]["path"], str(os.path.abspath(file_path_list[0]))) + self.assertEqual(meta["patch"][0]["level"], patch_info["level"]) + self.assertTupleEqual(meta["patch"][0]["size"], expected_img.shape[1:]) + self.assertTupleEqual(meta["patch"][0]["location"], patch_info["location"]) @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) @skipUnless(has_tiff, "Requires tifffile.")