Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f6716e2
Make all transforms optional
bhashemian Apr 27, 2022
6506edd
Merge branch 'dev' into fix-smartcache
bhashemian Apr 27, 2022
ee59746
Merge branch 'dev' into fix-smartcache
bhashemian Apr 28, 2022
aa08be9
Merge branch 'dev' of github.com:Project-MONAI/MONAI into fix-smartcache
bhashemian Apr 28, 2022
f21fe5a
Update wsireader tests
bhashemian Apr 28, 2022
1610bbc
Remove optional from PersistentDataset and its derivatives
bhashemian Apr 28, 2022
3d9516d
Add unittests for cache without transform
bhashemian Apr 28, 2022
b1656f8
Merge branch 'fix-smartcache' of github.com:behxyz/MONAI into fix-sma…
bhashemian Apr 28, 2022
a9f24c5
Add default replace_rate
bhashemian Apr 28, 2022
0d6450b
Add default value
bhashemian Apr 28, 2022
20c4882
Set default replace_rate to 0.1
bhashemian Apr 29, 2022
7096ad8
Merge branch 'dev' into fix-smartcache
wyli Apr 29, 2022
f586386
Update metadata to include path
bhashemian May 2, 2022
9dc6ca4
Adds SmartCachePatchWSIDataset
bhashemian May 2, 2022
6ad1b30
Add unittests for SmartCachePatchWSIDataset
bhashemian May 2, 2022
c548750
Update references
bhashemian May 2, 2022
e52d846
Merge branch 'fix-smartcache' of github.com:behxyz/MONAI into fix-sma…
bhashemian May 2, 2022
bee7f87
Merge branch 'dev' of github.com:Project-MONAI/MONAI into fix-smartcache
bhashemian May 2, 2022
c874f8e
Update docs
bhashemian May 2, 2022
21e08c4
Remove smart cache
bhashemian May 2, 2022
a9de037
Remove unused imports
bhashemian May 2, 2022
0e88792
Add path metadata for OpenSlide
bhashemian May 3, 2022
ffd439e
Update metadata to be unified across different backends
bhashemian May 3, 2022
408519a
Merge branch 'dev' into fix-smartcache
bhashemian May 4, 2022
d4a8f55
Merge branch 'dev' of github.com:Project-MONAI/MONAI into fix-smartcache
bhashemian May 6, 2022
a7d2dcb
Update wsi metadata for multi wsi objects
bhashemian May 6, 2022
960c117
Add unittests for wsi metadata
bhashemian May 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions monai/data/wsi_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np

Expand All @@ -32,10 +32,12 @@ class PatchWSIDataset(Dataset):
size: the size of patch to be extracted from the whole slide image.
level: the level at which the patches are to be extracted (defaults to 0).
transform: transforms to be executed on input data.
reader: the module to be used for loading whole slide imaging,
- if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
- if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
- if `reader` is an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
reader: the module to be used for loading whole slide imaging. If `reader` is

- a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
- an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

Note:
Expand All @@ -45,14 +47,14 @@ class PatchWSIDataset(Dataset):

[
{"image": "path/to/image1.tiff", "location": [200, 500], "label": 0},
{"image": "path/to/image2.tiff", "location": [100, 700], "label": 1}
{"image": "path/to/image2.tiff", "location": [100, 700], "size": [20, 20], "level": 2, "label": 1}
]

"""

def __init__(
self,
data: List,
data: Sequence,
size: Optional[Union[int, Tuple[int, int]]] = None,
level: Optional[int] = None,
transform: Optional[Callable] = None,
Expand Down
118 changes: 54 additions & 64 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from abc import abstractmethod
from os.path import abspath
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -53,6 +54,7 @@ class BaseWSIReader(ImageReader):
"""

supported_suffixes: List[str] = []
backend = ""

def __init__(self, level: int, **kwargs):
super().__init__()
Expand All @@ -63,7 +65,7 @@ def __init__(self, level: int, **kwargs):
@abstractmethod
def get_size(self, wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -83,6 +85,11 @@ def get_level_count(self, wsi) -> int:
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_file_path(self, wsi) -> str:
    """
    Return the file path for the given whole slide image object.

    This base implementation always raises; concrete readers are expected to override it.

    Args:
        wsi: a whole slide image object loaded from a file

    Raises:
        NotImplementedError: always, naming the subclass that failed to override this method.

    """
    subclass_name = self.__class__.__name__
    raise NotImplementedError(f"Subclass {subclass_name} must implement this method.")

@abstractmethod
def get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
Expand All @@ -102,20 +109,29 @@ def get_patch(
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
def get_metadata(
self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int
) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.

Args:
wsi: the whole slide image object, from which the patch is loaded
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
metadata: Dict = {
"backend": self.backend,
"original_channel_dim": 0,
"spatial_shape": np.asarray(patch.shape[1:]),
"wsi": {"path": self.get_file_path(wsi)},
"patch": {"location": location, "size": size, "level": level},
}
return metadata

def get_data(
self,
Expand Down Expand Up @@ -194,8 +210,26 @@ def get_data(
patch_list.append(patch)

# Set patch-related metadata
each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level)
metadata.update(each_meta)
each_meta = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)

if len(wsi) == 1:
metadata = each_meta
else:
if not metadata:
metadata = {
"backend": each_meta["backend"],
"original_channel_dim": each_meta["original_channel_dim"],
"spatial_shape": each_meta["spatial_shape"],
"wsi": [each_meta["wsi"]],
"patch": [each_meta["patch"]],
}
else:
if metadata["original_channel_dim"] != each_meta["original_channel_dim"]:
raise ValueError("original_channel_dim is not consistent across wsi objects.")
if any(metadata["spatial_shape"] != each_meta["spatial_shape"]):
raise ValueError("spatial_shape is not consistent across wsi objects.")
metadata["wsi"].append(each_meta["wsi"])
metadata["patch"].append(each_meta["patch"])

return _stack_images(patch_list, metadata), metadata

Expand Down Expand Up @@ -247,7 +281,7 @@ def get_level_count(self, wsi) -> int:

def get_size(self, wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -256,19 +290,9 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]:
"""
return self.reader.get_size(wsi, level)

def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.

Args:
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0

"""
return self.reader.get_metadata(patch=patch, size=size, location=location, level=level)
def get_file_path(self, wsi) -> str:
    """
    Return the file path for the given whole slide image object.

    The lookup is delegated to the backend reader stored in ``self.reader``.

    Args:
        wsi: a whole slide image object loaded from a file

    """
    backend_reader = self.reader
    return backend_reader.get_file_path(wsi)

def get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
Expand Down Expand Up @@ -317,6 +341,7 @@ class CuCIMWSIReader(BaseWSIReader):
"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "cucim"

def __init__(self, level: int = 0, **kwargs):
super().__init__(level, **kwargs)
Expand All @@ -335,7 +360,7 @@ def get_level_count(wsi) -> int:
@staticmethod
def get_size(wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -344,27 +369,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
"""
return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0])

def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.

Args:
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0

"""
metadata: Dict = {
"backend": "cucim",
"spatial_shape": np.asarray(patch.shape[1:]),
"original_channel_dim": 0,
"location": location,
"size": size,
"level": level,
}
return metadata
def get_file_path(self, wsi) -> str:
    """
    Return the absolute file path for the given whole slide image object.

    The path is read from the object's ``path`` attribute and made absolute.

    Args:
        wsi: a whole slide image object loaded from a file

    """
    absolute_path = abspath(wsi.path)
    return str(absolute_path)

def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
Expand Down Expand Up @@ -440,6 +447,7 @@ class OpenSlideWSIReader(BaseWSIReader):
"""

supported_suffixes = ["tif", "tiff", "svs"]
backend = "openslide"

def __init__(self, level: int = 0, **kwargs):
super().__init__(level, **kwargs)
Expand All @@ -458,7 +466,7 @@ def get_level_count(wsi) -> int:
@staticmethod
def get_size(wsi, level: int) -> Tuple[int, int]:
"""
Returns the size of the whole slide image at a given level.
Returns the size (height, width) of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
Expand All @@ -467,27 +475,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
"""
return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0])

def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
"""
Returns metadata of the extracted patch from the whole slide image.

Args:
patch: extracted patch from whole slide image
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
size: (height, width) tuple giving the patch size at the given level (`level`).
If None, it is set to the full image size at the given level.
level: the level number. Defaults to 0

"""
metadata: Dict = {
"backend": "openslide",
"spatial_shape": np.asarray(patch.shape[1:]),
"original_channel_dim": 0,
"location": location,
"size": size,
"level": level,
}
return metadata
def get_file_path(self, wsi) -> str:
    """
    Return the absolute file path for the given whole slide image object.

    The path is read from the object's ``_filename`` attribute and made absolute.

    Args:
        wsi: a whole slide image object loaded from a file

    """
    absolute_path = abspath(wsi._filename)
    return str(absolute_path)

def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
Expand Down
29 changes: 22 additions & 7 deletions tests/test_wsireader_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,13 @@ class Tests(unittest.TestCase):
def test_read_whole_image(self, file_path, level, expected_shape):
reader = WSIReader(self.backend, level=level)
with reader.read(file_path) as img_obj:
img = reader.get_data(img_obj)[0]
img, meta = reader.get_data(img_obj)
self.assertTupleEqual(img.shape, expected_shape)
self.assertEqual(meta["backend"], self.backend)
self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path)))
self.assertEqual(meta["patch"]["level"], level)
self.assertTupleEqual(meta["patch"]["size"], expected_shape[1:])
self.assertTupleEqual(meta["patch"]["location"], (0, 0))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_read_region(self, file_path, patch_info, expected_img):
Expand All @@ -138,29 +143,39 @@ def test_read_region(self, file_path, patch_info, expected_img):
reader.get_data(img_obj, **patch_info)[0]
else:
# Read twice to check multiple calls
img = reader.get_data(img_obj, **patch_info)[0]
img, meta = reader.get_data(img_obj, **patch_info)
img2 = reader.get_data(img_obj, **patch_info)[0]
self.assertTupleEqual(img.shape, img2.shape)
self.assertIsNone(assert_array_equal(img, img2))
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))
self.assertEqual(meta["backend"], self.backend)
self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path)))
self.assertEqual(meta["patch"]["level"], patch_info["level"])
self.assertTupleEqual(meta["patch"]["size"], expected_img.shape[1:])
self.assertTupleEqual(meta["patch"]["location"], patch_info["location"])

@parameterized.expand([TEST_CASE_3])
def test_read_region_multi_wsi(self, file_path, patch_info, expected_img):
def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img):
kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {}
reader = WSIReader(self.backend, **kwargs)
img_obj = reader.read(file_path, **kwargs)
img_obj_list = reader.read(file_path_list, **kwargs)
if self.backend == "tifffile":
with self.assertRaises(ValueError):
reader.get_data(img_obj, **patch_info)[0]
reader.get_data(img_obj_list, **patch_info)[0]
else:
# Read twice to check multiple calls
img = reader.get_data(img_obj, **patch_info)[0]
img2 = reader.get_data(img_obj, **patch_info)[0]
img, meta = reader.get_data(img_obj_list, **patch_info)
img2 = reader.get_data(img_obj_list, **patch_info)[0]
self.assertTupleEqual(img.shape, img2.shape)
self.assertIsNone(assert_array_equal(img, img2))
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))
self.assertEqual(meta["backend"], self.backend)
self.assertEqual(meta["wsi"][0]["path"], str(os.path.abspath(file_path_list[0])))
self.assertEqual(meta["patch"][0]["level"], patch_info["level"])
self.assertTupleEqual(meta["patch"][0]["size"], expected_img.shape[1:])
self.assertTupleEqual(meta["patch"][0]["location"], patch_info["location"])

@parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
@skipUnless(has_tiff, "Requires tifffile.")
Expand Down