diff --git a/docs/source/data.rst b/docs/source/data.rst index 0910001783..c968d72945 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -152,11 +152,24 @@ PILReader .. autoclass:: PILReader :members: +Whole slide image reader +------------------------ + +BaseWSIReader +~~~~~~~~~~~~~ +.. autoclass:: BaseWSIReader + :members: + WSIReader ~~~~~~~~~ .. autoclass:: WSIReader :members: +CuCIMWSIReader +~~~~~~~~~~~~~~ +.. autoclass:: CuCIMWSIReader + :members: + Image writer ------------ diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index 71f3214ea4..756223a784 100644 --- a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -16,7 +16,7 @@ import numpy as np from monai.data import Dataset, SmartCacheDataset -from monai.data.image_reader import WSIReader +from monai.data.wsi_reader import WSIReader from monai.utils import ensure_tuple_rep __all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset", "MaskedInferenceWSIDataset"] diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py index 6073bd0cda..e48f2128fe 100644 --- a/monai/apps/pathology/metrics/lesion_froc.py +++ b/monai/apps/pathology/metrics/lesion_froc.py @@ -14,7 +14,7 @@ import numpy as np from monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask -from monai.data.image_reader import WSIReader +from monai.data.wsi_reader import WSIReader from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score from monai.utils import min_version, optional_import diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 19ca29eafa..ca4be87ef6 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -34,7 +34,7 @@ from .folder_layout import FolderLayout from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, @@ -87,3 +87,4 @@ worker_init_fn, zoom_affine, ) +from .wsi_reader import BaseWSIReader, CuCIMWSIReader, WSIReader diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index ca77178e0b..f5d7fdef9d 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -19,8 +19,7 @@ from monai.config import DtypeLike, KeysCollection, PathLike from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format, orientation_ras_lps -from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg +from monai.utils import ensure_tuple, optional_import, require_pkg if TYPE_CHECKING: import itk @@ -39,7 +38,7 @@ CuImage, _ = optional_import("cucim", name="CuImage") TiffFile, _ = optional_import("tifffile", name="TiffFile") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"] class ImageReader(ABC): @@ -714,265 +713,3 @@ def _get_spatial_shape(self, img): img: a PIL Image object loaded from an image file. """ return np.asarray((img.width, img.height)) - - -class WSIReader(ImageReader): - """ - Read whole slide images and extract patches. - - Args: - backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile". - level: the whole slide image level at which the image is extracted. (default=0) - This is overridden if the level argument is provided in `get_data`. - kwargs: additional args for backend reading API in `read()`, more details in `cuCIM`, `TiffFile`, `OpenSlide`: - https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100. - https://github.com/cgohlke/tifffile. - https://openslide.org/api/python/#openslide.OpenSlide. - - Note: - While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images - without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory - before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for - patch extraction. - - """ - - def __init__(self, backend: str = "OpenSlide", level: int = 0, **kwargs): - super().__init__() - self.backend = backend.lower() - func = require_pkg(self.backend)(self._set_reader) - self.wsi_reader = func(self.backend) - self.level = level - self.kwargs = kwargs - - @staticmethod - def _set_reader(backend: str): - if backend == "openslide": - return OpenSlide - if backend == "cucim": - return CuImage - if backend == "tifffile": - return TiffFile - raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.") - - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: - """ - Verify whether the specified file or files format is supported by WSI reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - """ - return is_supported_format(filename, ["tif", "tiff"]) - - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): - """ - Read image data from given file or list of files. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for backend reading API in `read()`, will override `self.kwargs` for existing keys. - more details in `cuCIM`, `TiffFile`, `OpenSlide`: - https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100. - https://github.com/cgohlke/tifffile. - https://openslide.org/api/python/#openslide.OpenSlide. - - Returns: - image object or list of image objects - - """ - img_: List = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = self.wsi_reader(name, **kwargs_) - if self.backend == "openslide": - img.shape = (img.dimensions[1], img.dimensions[0], 3) - img_.append(img) - - return img_ if len(filenames) > 1 else img_[0] - - def get_data( - self, - img, - location: Tuple[int, int] = (0, 0), - size: Optional[Tuple[int, int]] = None, - level: Optional[int] = None, - dtype: DtypeLike = np.uint8, - grid_shape: Tuple[int, int] = (1, 1), - patch_size: Optional[Union[int, Tuple[int, int]]] = None, - ): - """ - Extract regions as numpy array from WSI image and return them. - - Args: - img: a WSIReader image object loaded from a file, or list of CuImage objects - location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame, - or list of tuples (default=(0, 0)) - size: (height, width) tuple giving the region size, or list of tuples (default to full image size) - This is the size of image at the given level (`level`) - level: the level number, or list of level numbers (default=0) - dtype: the data type of output image - grid_shape: (row, columns) tuple define a grid to extract patches on that - patch_size: (height, width) the size of extracted patches at the given level - """ - # Verify inputs - if level is None: - level = self.level - max_level = self._get_max_level(img) - if level > max_level: - raise ValueError(f"The maximum level of this image is {max_level} while level={level} is requested)!") - - # Extract a region or the entire image - region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) - - # Add necessary metadata - metadata: Dict = {} - metadata["spatial_shape"] = np.asarray(region.shape[:-1]) - metadata["original_channel_dim"] = -1 - - # Make it channel first - region = EnsureChannelFirst()(region, metadata) - - # Split into patches - if patch_size is None: - patches = region - else: - tuple_patch_size = ensure_tuple_rep(patch_size, 2) - patches = self._extract_patches( - region, patch_size=tuple_patch_size, grid_shape=grid_shape, dtype=dtype # type: ignore - ) - - return patches, metadata - - def _get_max_level(self, img_obj): - """ - Return the maximum number of levels in the whole slide image - Args: - img: the whole slide image object - - """ - if self.backend == "openslide": - return img_obj.level_count - 1 - if self.backend == "cucim": - return img_obj.resolutions["level_count"] - 1 - if self.backend == "tifffile": - return len(img_obj.pages) - 1 - - def _get_image_size(self, img, size, level, location): - """ - Calculate the maximum region size for the given level and starting location (if size is None). - Note that region size in OpenSlide and cuCIM are WxH (but the final image output would be HxW) - """ - if size is not None: - return size[::-1] - - max_size = [] - downsampling_factor = [] - if self.backend == "openslide": - downsampling_factor = img.level_downsamples[level] - max_size = img.level_dimensions[level] - elif self.backend == "cucim": - downsampling_factor = img.resolutions["level_downsamples"][level] - max_size = img.resolutions["level_dimensions"][level] - - # subtract the top left corner of the patch (at given level) from maximum size - location_at_level = (round(location[1] / downsampling_factor), round(location[0] / downsampling_factor)) - size = [max_size[i] - location_at_level[i] for i in range(len(max_size))] - - return size - - def _extract_region( - self, - img_obj, - size: Optional[Tuple[int, int]], - location: Tuple[int, int] = (0, 0), - level: int = 0, - dtype: DtypeLike = np.uint8, - ): - if self.backend == "tifffile": - # Read the entire image - if size is not None: - raise ValueError( - f"TiffFile backend reads the entire image only, so size '{size}'' should not be provided!", - "For more flexibility or extracting regions, please use cuCIM or OpenSlide backend.", - ) - if location != (0, 0): - raise ValueError( - f"TiffFile backend reads the entire image only, so location '{location}' should not be provided!", - "For more flexibility and extracting regions, please use cuCIM or OpenSlide backend.", - ) - region = img_obj.asarray(level=level) - else: - # Get region size to be extracted - region_size = self._get_image_size(img_obj, size, level, location) - # reverse the order of location's dimensions to become WxH (for cuCIM and OpenSlide) - region_location = location[::-1] - # Extract a region (or the entire image) - region = img_obj.read_region(location=region_location, size=region_size, level=level) - - region = self.convert_to_rgb_array(region, dtype) - return region - - def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8): - """Convert to RGB mode and numpy array""" - if self.backend == "openslide": - # convert to RGB - raw_region = raw_region.convert("RGB") - - # convert to numpy (if not already in numpy) - raw_region = np.asarray(raw_region, dtype=dtype) - - # check if the image has three dimensions (2D + color) - if raw_region.ndim != 3: - raise ValueError( - f"The input image dimension should be 3 but {raw_region.ndim} is given. " - "`WSIReader` is designed to work only with 2D colored images." - ) - - # check if the color channel is 3 (RGB) or 4 (RGBA) - if raw_region.shape[-1] not in [3, 4]: - raise ValueError( - f"There should be three or four color channels but {raw_region.shape[-1]} is given. " - "`WSIReader` is designed to work only with 2D colored images." - ) - - # remove alpha channel if exist (RGBA) - if raw_region.shape[-1] > 3: - raw_region = raw_region[..., :3] - - return raw_region - - def _extract_patches( - self, - region: np.ndarray, - grid_shape: Tuple[int, int] = (1, 1), - patch_size: Optional[Tuple[int, int]] = None, - dtype: DtypeLike = np.uint8, - ): - if patch_size is None and grid_shape == (1, 1): - return region - - n_patches = grid_shape[0] * grid_shape[1] - region_size = region.shape[1:] - - if patch_size is None: - patch_size = (region_size[0] // grid_shape[0], region_size[1] // grid_shape[1]) - - # split the region into patches on the grid and center crop them to patch size - flat_patch_grid = np.zeros((n_patches, 3, patch_size[0], patch_size[1]), dtype=dtype) - start_points = [ - np.round(region_size[i] * (0.5 + np.arange(grid_shape[i])) / grid_shape[i] - patch_size[i] / 2).astype(int) - for i in range(2) - ] - idx = 0 - for y_start in start_points[1]: - for x_start in start_points[0]: - x_end = x_start + patch_size[0] - y_end = y_start + patch_size[1] - flat_patch_grid[idx] = region[:, x_start:x_end, y_start:y_end] - idx += 1 - - return flat_patch_grid diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py new file mode 100644 index 0000000000..4899fb8830 --- /dev/null +++ b/monai/data/wsi_reader.py @@ -0,0 +1,420 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np + +from monai.config import DtypeLike, PathLike +from monai.data.image_reader import ImageReader, _stack_images +from monai.data.utils import is_supported_format +from monai.transforms.utility.array import EnsureChannelFirst +from monai.utils import ensure_tuple, optional_import, require_pkg + +CuImage, _ = optional_import("cucim", name="CuImage") + +__all__ = ["BaseWSIReader", "WSIReader", "CuCIMWSIReader"] + + +class BaseWSIReader(ImageReader): + """ + An abstract class that defines APIs to load patches from whole slide image files. + + Typical usage of a concrete implementation of this class is: + + .. code-block:: python + + image_reader = MyWSIReader() + wsi = image_reader.read(, **kwargs) + img_data, meta_data = image_reader.get_data(wsi) + + - The `read` call converts an image filename into whole slide image object, + - The `get_data` call fetches the image data, as well as meta data. + + The following methods needs to be implemented for any concrete implementation of this class: + + - `read` reads a whole slide image object from a given file + - `get_size` returns the size of the whole slide image of a given wsi object at a given level. + - `get_level_count` returns the number of levels in the whole slide image + - `get_patch` extracts and returns a patch image form the whole slide image + - `get_metadata` extracts and returns metadata for a whole slide image and a specific patch. + + + """ + + supported_suffixes: List[str] = [] + + def __init__(self, level: int, **kwargs): + super().__init__() + self.level = level + self.kwargs = kwargs + self.metadata: Dict[Any, Any] = {} + + @abstractmethod + def get_size(self, wsi, level: int) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_level_count(self, wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def get_data( + self, + wsi, + location: Tuple[int, int] = (0, 0), + size: Optional[Tuple[int, int]] = None, + level: Optional[int] = None, + dtype: DtypeLike = np.uint8, + mode: str = "RGB", + ) -> Tuple[np.ndarray, Dict]: + """ + Verifies inputs, extracts patches from WSI image and generates metadata, and return them. + + Args: + wsi: a whole slide image object loaded from a file or a list of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + Returns: + a tuples, where the first element is an image patch [CxHxW] or stack of patches, + and second element is a dictionary of metadata + """ + patch_list: List = [] + metadata = {} + # CuImage object is iterable, so ensure_tuple won't work on single object + if not isinstance(wsi, List): + wsi = [wsi] + for each_wsi in ensure_tuple(wsi): + # Verify magnification level + if level is None: + level = self.level + max_level = self.get_level_count(each_wsi) - 1 + if level > max_level: + raise ValueError(f"The maximum level of this image is {max_level} while level={level} is requested)!") + + # Verify location + if location is None: + location = (0, 0) + wsi_size = self.get_size(each_wsi, level) + if location[0] > wsi_size[0] or location[1] > wsi_size[1]: + raise ValueError(f"Location is outside of the image: location={location}, image size={wsi_size}") + + # Verify size + if size is None: + if location != (0, 0): + raise ValueError("Patch size should be defined to exctract patches.") + size = self.get_size(each_wsi, level) + else: + if size[0] <= 0 or size[1] <= 0: + raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}") + + # Extract a patch or the entire image + patch = self.get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) + + # check if the image has three dimensions (2D + color) + if patch.ndim != 3: + raise ValueError( + f"The image dimension should be 3 but has {patch.ndim}. " + "`WSIReader` is designed to work only with 2D images with color channel." + ) + + # Create a list of patches + patch_list.append(patch) + + # Set patch-related metadata + each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level) + metadata.update(each_meta) + + return _stack_images(patch_list, metadata), metadata + + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + """ + Verify whether the specified file or files format is supported by WSI reader. + + The list of supported suffixes are read from `self.supported_suffixes`. + + Args: + filename: filename or a list of filenames to read. + + """ + return is_supported_format(filename, self.supported_suffixes) + + +class WSIReader(BaseWSIReader): + """ + Read whole slide images and extract patches using different backend libraries + + Args: + backend: the name of backend whole slide image reader library, the default is cuCIM. + level: the level at which patches are extracted. + kwargs: additional arguments to be passed to the backend library + + """ + + def __init__(self, backend="cucim", level: int = 0, **kwargs): + super().__init__(level, **kwargs) + self.backend = backend.lower() + # Any new backend can be added below + if self.backend == "cucim": + self.reader = CuCIMWSIReader(level=level, **kwargs) + else: + raise ValueError("The supported backends are: cucim") + self.supported_suffixes = self.reader.supported_suffixes + + def get_level_count(self, wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + return self.reader.get_level_count(wsi) + + def get_size(self, wsi, level) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + return self.reader.get_size(wsi, level) + + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + return self.reader.get_metadata(patch=patch, size=size, location=location, level=level) + + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + return self.reader.get_patch(wsi=wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) + + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + """ + Read whole slide image objects from given file or list of files. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for the reader module (overrides `self.kwargs` for existing keys). + + Returns: + whole slide image object or list of such objects + + """ + return self.reader.read(data=data, **kwargs) + + +@require_pkg(pkg_name="cucim") +class CuCIMWSIReader(BaseWSIReader): + """ + Read whole slide images and extract patches without loading the whole slide image into the memory. + + Args: + level: the whole slide image level at which the image is extracted. (default=0) + This is overridden if the level argument is provided in `get_data`. + kwargs: additional args for `cucim.CuImage` module: + https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h + + """ + + supported_suffixes = ["tif", "tiff", "svs"] + + def __init__(self, level: int = 0, **kwargs): + super().__init__(level, **kwargs) + + @staticmethod + def get_level_count(wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + return wsi.resolutions["level_count"] # type: ignore + + @staticmethod + def get_size(wsi, level) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) + + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + metadata: Dict = { + "backend": "cucim", + "spatial_shape": np.asarray(patch.shape[1:]), + "original_channel_dim": 0, + "location": location, + "size": size, + "level": level, + } + return metadata + + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + """ + Read whole slide image objects from given file or list of files. + + Args: + data: file name or a list of file names to read. + kwargs: additional args that overrides `self.kwargs` for existing keys. + For more details look at https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h + + Returns: + whole slide image object or list of such objects + + """ + wsi_list: List = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for filename in filenames: + wsi = CuImage(filename, **kwargs_) + wsi_list.append(wsi) + + return wsi_list if len(filenames) > 1 else wsi_list[0] + + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + # Extract a patch or the entire image + # (reverse the order of location and size to become WxH for cuCIM) + patch: np.ndarray = wsi.read_region(location=location[::-1], size=size[::-1], level=level) + + # Convert to numpy + patch = np.asarray(patch, dtype=dtype) + + # Make it channel first + patch = EnsureChannelFirst()(patch, {"original_channel_dim": -1}) # type: ignore + + # Check if the color channel is 3 (RGB) or 4 (RGBA) + if mode == "RGBA" and patch.shape[0] != 4: + raise ValueError( + f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}." + ) + + if mode in "RGB": + if patch.shape[0] not in [3, 4]: + raise ValueError( + f"The image is expected to have three or four color channels in '{mode}' mode but has {patch.shape[0]}. " + ) + patch = patch[:3] + + return patch diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 6ee02143b8..7b288f6040 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.data import DataLoader, Dataset -from monai.data.image_reader import WSIReader +from monai.data.wsi_reader import WSIReader from monai.transforms import Compose, LoadImaged, ToTensord from monai.utils import first, optional_import from monai.utils.enums import PostFix @@ -57,29 +57,17 @@ ] TEST_CASE_3 = [ - FILE_PATH, - {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 2}, - np.array( + [FILE_PATH, FILE_PATH], + {"location": (0, 0), "size": (2, 1), "level": 2}, + np.concatenate( [ - [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], - [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], - ] + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + ], + axis=0, ), ] -TEST_CASE_4 = [ - FILE_PATH, - {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 1}, - np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), -] - -TEST_CASE_5 = [ - FILE_PATH, - {"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)}, - np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]), -] - - TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW @@ -138,7 +126,7 @@ def test_read_whole_image(self, file_path, level, expected_shape): img = reader.get_data(img_obj)[0] self.assertTupleEqual(img.shape, expected_shape) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_read_region(self, file_path, patch_info, expected_img): kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} reader = WSIReader(self.backend, **kwargs) @@ -155,17 +143,22 @@ def test_read_region(self, file_path, patch_info, expected_img): self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) - @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) - def test_read_patches(self, file_path, patch_info, expected_img): - reader = WSIReader(self.backend) - with reader.read(file_path) as img_obj: - if self.backend == "tifffile": - with self.assertRaises(ValueError): - reader.get_data(img_obj, **patch_info)[0] - else: - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) + @parameterized.expand([TEST_CASE_3]) + def test_read_region_multi_wsi(self, file_path, patch_info, expected_img): + kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} + reader = WSIReader(self.backend, **kwargs) + img_obj = reader.read(file_path, **kwargs) + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + # Read twice to check multiple calls + img = reader.get_data(img_obj, **patch_info)[0] + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + self.assertIsNone(assert_array_equal(img, img2)) + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) @skipUnless(has_tiff, "Requires tifffile.") @@ -221,19 +214,5 @@ def setUpClass(cls): cls.backend = "cucim" -@skipUnless(has_osl, "Requires OpenSlide") -class TestOpenSlide(WSIReaderTests.Tests): - @classmethod - def setUpClass(cls): - cls.backend = "openslide" - - -@skipUnless(has_tiff, "Requires TiffFile") -class TestTiffFile(WSIReaderTests.Tests): - @classmethod - def setUpClass(cls): - cls.backend = "tifffile" - - if __name__ == "__main__": unittest.main()