diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index a895e8aa45..750b3fda20 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -10,7 +10,7 @@ # limitations under the License. import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np @@ -32,10 +32,12 @@ class PatchWSIDataset(Dataset): size: the size of patch to be extracted from the whole slide image. level: the level at which the patches to be extracted (default to 0). transform: transforms to be executed on input data. - reader: the module to be used for loading whole slide imaging, - - if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM. - - if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader. - - if `reader` is an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader. + reader: the module to be used for loading whole slide imaging. If `reader` is + + - a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM. + - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader. + - an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader. + kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class Note: @@ -45,14 +47,14 @@ class PatchWSIDataset(Dataset): [ {"image": "path/to/image1.tiff", "location": [200, 500], "label": 0}, - {"image": "path/to/image2.tiff", "location": [100, 700], "label": 1} + {"image": "path/to/image2.tiff", "location": [100, 700], "size": [20, 20], "level": 2, "label": 1} ] """ def __init__( self, - data: List, + data: Sequence, size: Optional[Union[int, Tuple[int, int]]] = None, level: Optional[int] = None, transform: Optional[Callable] = None, diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 02032a0ae6..8dee1f453e 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -10,6 +10,7 @@ # limitations under the License. from abc import abstractmethod +from os.path import abspath from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -53,6 +54,7 @@ class BaseWSIReader(ImageReader): """ supported_suffixes: List[str] = [] + backend = "" def __init__(self, level: int, **kwargs): super().__init__() @@ -63,7 +65,7 @@ def __init__(self, level: int, **kwargs): @abstractmethod def get_size(self, wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -83,6 +85,11 @@ def get_level_count(self, wsi) -> int: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod def get_patch( self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str @@ -102,12 +109,14 @@ def get_patch( """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - @abstractmethod - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + def get_metadata( + self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int + ) -> Dict: """ Returns metadata of the extracted patch from the whole slide image. Args: + wsi: the whole slide image object, from which the patch is loaded patch: extracted patch from whole slide image location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). @@ -115,7 +124,14 @@ def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple level: the level number. Defaults to 0 """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + metadata: Dict = { + "backend": self.backend, + "original_channel_dim": 0, + "spatial_shape": np.asarray(patch.shape[1:]), + "wsi": {"path": self.get_file_path(wsi)}, + "patch": {"location": location, "size": size, "level": level}, + } + return metadata def get_data( self, @@ -194,8 +210,26 @@ def get_data( patch_list.append(patch) # Set patch-related metadata - each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level) - metadata.update(each_meta) + each_meta = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level) + + if len(wsi) == 1: + metadata = each_meta + else: + if not metadata: + metadata = { + "backend": each_meta["backend"], + "original_channel_dim": each_meta["original_channel_dim"], + "spatial_shape": each_meta["spatial_shape"], + "wsi": [each_meta["wsi"]], + "patch": [each_meta["patch"]], + } + else: + if metadata["original_channel_dim"] != each_meta["original_channel_dim"]: + raise ValueError("original_channel_dim is not consistent across wsi objects.") + if any(metadata["spatial_shape"] != each_meta["spatial_shape"]): + raise ValueError("spatial_shape is not consistent across wsi objects.") + metadata["wsi"].append(each_meta["wsi"]) + metadata["patch"].append(each_meta["patch"]) return _stack_images(patch_list, metadata), metadata @@ -247,7 +281,7 @@ def get_level_count(self, wsi) -> int: def get_size(self, wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -256,19 +290,9 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]: """ return self.reader.get_size(wsi, level) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - return self.reader.get_metadata(patch=patch, size=size, location=location, level=level) + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return self.reader.get_file_path(wsi) def get_patch( self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str @@ -317,6 +341,7 @@ class CuCIMWSIReader(BaseWSIReader): """ supported_suffixes = ["tif", "tiff", "svs"] + backend = "cucim" def __init__(self, level: int = 0, **kwargs): super().__init__(level, **kwargs) @@ -335,7 +360,7 @@ def get_level_count(wsi) -> int: @staticmethod def get_size(wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -344,27 +369,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - metadata: Dict = { - "backend": "cucim", - "spatial_shape": np.asarray(patch.shape[1:]), - "original_channel_dim": 0, - "location": location, - "size": size, - "level": level, - } - return metadata + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return str(abspath(wsi.path)) def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ @@ -440,6 +447,7 @@ class OpenSlideWSIReader(BaseWSIReader): """ supported_suffixes = ["tif", "tiff", "svs"] + backend = "openslide" def __init__(self, level: int = 0, **kwargs): super().__init__(level, **kwargs) @@ -458,7 +466,7 @@ def get_level_count(wsi) -> int: @staticmethod def get_size(wsi, level: int) -> Tuple[int, int]: """ - Returns the size of the whole slide image at a given level. + Returns the size (height, width) of the whole slide image at a given level. Args: wsi: a whole slide image object loaded from a file @@ -467,27 +475,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0]) - def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: - """ - Returns metadata of the extracted patch from the whole slide image. - - Args: - patch: extracted patch from whole slide image - location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). - size: (height, width) tuple giving the patch size at the given level (`level`). - If None, it is set to the full image size at the given level. - level: the level number. Defaults to 0 - - """ - metadata: Dict = { - "backend": "openslide", - "spatial_shape": np.asarray(patch.shape[1:]), - "original_channel_dim": 0, - "location": location, - "size": size, - "level": level, - } - return metadata + def get_file_path(self, wsi) -> str: + """Return the file path for the WSI object""" + return str(abspath(wsi._filename)) def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 2ac4125f97..4faec53978 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -125,8 +125,13 @@ class Tests(unittest.TestCase): def test_read_whole_image(self, file_path, level, expected_shape): reader = WSIReader(self.backend, level=level) with reader.read(file_path) as img_obj: - img = reader.get_data(img_obj)[0] + img, meta = reader.get_data(img_obj) self.assertTupleEqual(img.shape, expected_shape) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path))) + self.assertEqual(meta["patch"]["level"], level) + self.assertTupleEqual(meta["patch"]["size"], expected_shape[1:]) + self.assertTupleEqual(meta["patch"]["location"], (0, 0)) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_read_region(self, file_path, patch_info, expected_img): @@ -138,29 +143,39 @@ def test_read_region(self, file_path, patch_info, expected_img): reader.get_data(img_obj, **patch_info)[0] else: # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] + img, meta = reader.get_data(img_obj, **patch_info) img2 = reader.get_data(img_obj, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) self.assertIsNone(assert_array_equal(img, img2)) self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path))) + self.assertEqual(meta["patch"]["level"], patch_info["level"]) + self.assertTupleEqual(meta["patch"]["size"], expected_img.shape[1:]) + self.assertTupleEqual(meta["patch"]["location"], patch_info["location"]) @parameterized.expand([TEST_CASE_3]) - def test_read_region_multi_wsi(self, file_path, patch_info, expected_img): + def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img): kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} reader = WSIReader(self.backend, **kwargs) - img_obj = reader.read(file_path, **kwargs) + img_obj_list = reader.read(file_path_list, **kwargs) if self.backend == "tifffile": with self.assertRaises(ValueError): - reader.get_data(img_obj, **patch_info)[0] + reader.get_data(img_obj_list, **patch_info)[0] else: # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] - img2 = reader.get_data(img_obj, **patch_info)[0] + img, meta = reader.get_data(img_obj_list, **patch_info) + img2 = reader.get_data(img_obj_list, **patch_info)[0] self.assertTupleEqual(img.shape, img2.shape) self.assertIsNone(assert_array_equal(img, img2)) self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) + self.assertEqual(meta["backend"], self.backend) + self.assertEqual(meta["wsi"][0]["path"], str(os.path.abspath(file_path_list[0]))) + self.assertEqual(meta["patch"][0]["level"], patch_info["level"]) + self.assertTupleEqual(meta["patch"][0]["size"], expected_img.shape[1:]) + self.assertTupleEqual(meta["patch"][0]["location"], patch_info["location"]) @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) @skipUnless(has_tiff, "Requires tifffile.")