From 78340a91543b315dc5671b2c686fab66fb612a66 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 19 Apr 2022 10:28:41 -0400 Subject: [PATCH 01/14] Redesign whole slide image reading (#4107) * Redesign BaseWSIReader, WSIReader, CuCIMWSIReader Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Add unittests for WSIReader Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Add image mode for output validation Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update references to new WSIReader Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Remove legacy WSIReader Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update unittests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * sort imports Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Clean up imports Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update docstrings Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update docs and docstrings Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix a typo Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Remove redundant checking Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update read and other methods Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update wsireader to support multi image and update docstrings Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Make workaround for CuImage objects Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Add unittests for multi image reading Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update a note about cucim Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update type hints and docstrings Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/data.rst | 13 + monai/apps/pathology/data/datasets.py | 2 +- monai/apps/pathology/metrics/lesion_froc.py | 2 +- monai/data/__init__.py | 3 +- monai/data/image_reader.py | 267 +------------ monai/data/wsi_reader.py | 420 ++++++++++++++++++++ tests/test_wsireader.py | 71 ++-- 7 files changed, 464 insertions(+), 314 deletions(-) create mode 100644 monai/data/wsi_reader.py diff --git a/docs/source/data.rst b/docs/source/data.rst index 0910001783..c968d72945 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -152,11 +152,24 @@ PILReader .. autoclass:: PILReader :members: +Whole slide image reader +------------------------ + +BaseWSIReader +~~~~~~~~~~~~~ +.. autoclass:: BaseWSIReader + :members: + WSIReader ~~~~~~~~~ .. autoclass:: WSIReader :members: +CuCIMWSIReader +~~~~~~~~~~~~~~ +.. autoclass:: CuCIMWSIReader + :members: + Image writer ------------ diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index 71f3214ea4..756223a784 100644 --- a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -16,7 +16,7 @@ import numpy as np from monai.data import Dataset, SmartCacheDataset -from monai.data.image_reader import WSIReader +from monai.data.wsi_reader import WSIReader from monai.utils import ensure_tuple_rep __all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset", "MaskedInferenceWSIDataset"] diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py index 6073bd0cda..e48f2128fe 100644 --- a/monai/apps/pathology/metrics/lesion_froc.py +++ b/monai/apps/pathology/metrics/lesion_froc.py @@ -14,7 +14,7 @@ import numpy as np from monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask -from monai.data.image_reader import WSIReader +from monai.data.wsi_reader import WSIReader from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score from monai.utils import min_version, optional_import diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 19ca29eafa..ca4be87ef6 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -34,7 +34,7 @@ from .folder_layout import FolderLayout from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, @@ -87,3 +87,4 @@ worker_init_fn, zoom_affine, ) +from .wsi_reader import BaseWSIReader, CuCIMWSIReader, WSIReader diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index ca77178e0b..f5d7fdef9d 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -19,8 +19,7 @@ from monai.config import DtypeLike, KeysCollection, PathLike from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format, orientation_ras_lps -from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg +from monai.utils import ensure_tuple, optional_import, require_pkg if TYPE_CHECKING: import itk @@ -39,7 +38,7 @@ CuImage, _ = optional_import("cucim", name="CuImage") TiffFile, _ = optional_import("tifffile", name="TiffFile") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"] class ImageReader(ABC): @@ -714,265 +713,3 @@ def _get_spatial_shape(self, img): img: a PIL Image object loaded from an image file. """ return np.asarray((img.width, img.height)) - - -class WSIReader(ImageReader): - """ - Read whole slide images and extract patches. - - Args: - backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile". - level: the whole slide image level at which the image is extracted. (default=0) - This is overridden if the level argument is provided in `get_data`. - kwargs: additional args for backend reading API in `read()`, more details in `cuCIM`, `TiffFile`, `OpenSlide`: - https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100. - https://github.com/cgohlke/tifffile. - https://openslide.org/api/python/#openslide.OpenSlide. - - Note: - While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images - without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory - before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for - patch extraction. - - """ - - def __init__(self, backend: str = "OpenSlide", level: int = 0, **kwargs): - super().__init__() - self.backend = backend.lower() - func = require_pkg(self.backend)(self._set_reader) - self.wsi_reader = func(self.backend) - self.level = level - self.kwargs = kwargs - - @staticmethod - def _set_reader(backend: str): - if backend == "openslide": - return OpenSlide - if backend == "cucim": - return CuImage - if backend == "tifffile": - return TiffFile - raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.") - - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: - """ - Verify whether the specified file or files format is supported by WSI reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - """ - return is_supported_format(filename, ["tif", "tiff"]) - - def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): - """ - Read image data from given file or list of files. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for backend reading API in `read()`, will override `self.kwargs` for existing keys. - more details in `cuCIM`, `TiffFile`, `OpenSlide`: - https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100. - https://github.com/cgohlke/tifffile. - https://openslide.org/api/python/#openslide.OpenSlide. - - Returns: - image object or list of image objects - - """ - img_: List = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = self.wsi_reader(name, **kwargs_) - if self.backend == "openslide": - img.shape = (img.dimensions[1], img.dimensions[0], 3) - img_.append(img) - - return img_ if len(filenames) > 1 else img_[0] - - def get_data( - self, - img, - location: Tuple[int, int] = (0, 0), - size: Optional[Tuple[int, int]] = None, - level: Optional[int] = None, - dtype: DtypeLike = np.uint8, - grid_shape: Tuple[int, int] = (1, 1), - patch_size: Optional[Union[int, Tuple[int, int]]] = None, - ): - """ - Extract regions as numpy array from WSI image and return them. - - Args: - img: a WSIReader image object loaded from a file, or list of CuImage objects - location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame, - or list of tuples (default=(0, 0)) - size: (height, width) tuple giving the region size, or list of tuples (default to full image size) - This is the size of image at the given level (`level`) - level: the level number, or list of level numbers (default=0) - dtype: the data type of output image - grid_shape: (row, columns) tuple define a grid to extract patches on that - patch_size: (height, width) the size of extracted patches at the given level - """ - # Verify inputs - if level is None: - level = self.level - max_level = self._get_max_level(img) - if level > max_level: - raise ValueError(f"The maximum level of this image is {max_level} while level={level} is requested)!") - - # Extract a region or the entire image - region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) - - # Add necessary metadata - metadata: Dict = {} - metadata["spatial_shape"] = np.asarray(region.shape[:-1]) - metadata["original_channel_dim"] = -1 - - # Make it channel first - region = EnsureChannelFirst()(region, metadata) - - # Split into patches - if patch_size is None: - patches = region - else: - tuple_patch_size = ensure_tuple_rep(patch_size, 2) - patches = self._extract_patches( - region, patch_size=tuple_patch_size, grid_shape=grid_shape, dtype=dtype # type: ignore - ) - - return patches, metadata - - def _get_max_level(self, img_obj): - """ - Return the maximum number of levels in the whole slide image - Args: - img: the whole slide image object - - """ - if self.backend == "openslide": - return img_obj.level_count - 1 - if self.backend == "cucim": - return img_obj.resolutions["level_count"] - 1 - if self.backend == "tifffile": - return len(img_obj.pages) - 1 - - def _get_image_size(self, img, size, level, location): - """ - Calculate the maximum region size for the given level and starting location (if size is None). - Note that region size in OpenSlide and cuCIM are WxH (but the final image output would be HxW) - """ - if size is not None: - return size[::-1] - - max_size = [] - downsampling_factor = [] - if self.backend == "openslide": - downsampling_factor = img.level_downsamples[level] - max_size = img.level_dimensions[level] - elif self.backend == "cucim": - downsampling_factor = img.resolutions["level_downsamples"][level] - max_size = img.resolutions["level_dimensions"][level] - - # subtract the top left corner of the patch (at given level) from maximum size - location_at_level = (round(location[1] / downsampling_factor), round(location[0] / downsampling_factor)) - size = [max_size[i] - location_at_level[i] for i in range(len(max_size))] - - return size - - def _extract_region( - self, - img_obj, - size: Optional[Tuple[int, int]], - location: Tuple[int, int] = (0, 0), - level: int = 0, - dtype: DtypeLike = np.uint8, - ): - if self.backend == "tifffile": - # Read the entire image - if size is not None: - raise ValueError( - f"TiffFile backend reads the entire image only, so size '{size}'' should not be provided!", - "For more flexibility or extracting regions, please use cuCIM or OpenSlide backend.", - ) - if location != (0, 0): - raise ValueError( - f"TiffFile backend reads the entire image only, so location '{location}' should not be provided!", - "For more flexibility and extracting regions, please use cuCIM or OpenSlide backend.", - ) - region = img_obj.asarray(level=level) - else: - # Get region size to be extracted - region_size = self._get_image_size(img_obj, size, level, location) - # reverse the order of location's dimensions to become WxH (for cuCIM and OpenSlide) - region_location = location[::-1] - # Extract a region (or the entire image) - region = img_obj.read_region(location=region_location, size=region_size, level=level) - - region = self.convert_to_rgb_array(region, dtype) - return region - - def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8): - """Convert to RGB mode and numpy array""" - if self.backend == "openslide": - # convert to RGB - raw_region = raw_region.convert("RGB") - - # convert to numpy (if not already in numpy) - raw_region = np.asarray(raw_region, dtype=dtype) - - # check if the image has three dimensions (2D + color) - if raw_region.ndim != 3: - raise ValueError( - f"The input image dimension should be 3 but {raw_region.ndim} is given. " - "`WSIReader` is designed to work only with 2D colored images." - ) - - # check if the color channel is 3 (RGB) or 4 (RGBA) - if raw_region.shape[-1] not in [3, 4]: - raise ValueError( - f"There should be three or four color channels but {raw_region.shape[-1]} is given. " - "`WSIReader` is designed to work only with 2D colored images." - ) - - # remove alpha channel if exist (RGBA) - if raw_region.shape[-1] > 3: - raw_region = raw_region[..., :3] - - return raw_region - - def _extract_patches( - self, - region: np.ndarray, - grid_shape: Tuple[int, int] = (1, 1), - patch_size: Optional[Tuple[int, int]] = None, - dtype: DtypeLike = np.uint8, - ): - if patch_size is None and grid_shape == (1, 1): - return region - - n_patches = grid_shape[0] * grid_shape[1] - region_size = region.shape[1:] - - if patch_size is None: - patch_size = (region_size[0] // grid_shape[0], region_size[1] // grid_shape[1]) - - # split the region into patches on the grid and center crop them to patch size - flat_patch_grid = np.zeros((n_patches, 3, patch_size[0], patch_size[1]), dtype=dtype) - start_points = [ - np.round(region_size[i] * (0.5 + np.arange(grid_shape[i])) / grid_shape[i] - patch_size[i] / 2).astype(int) - for i in range(2) - ] - idx = 0 - for y_start in start_points[1]: - for x_start in start_points[0]: - x_end = x_start + patch_size[0] - y_end = y_start + patch_size[1] - flat_patch_grid[idx] = region[:, x_start:x_end, y_start:y_end] - idx += 1 - - return flat_patch_grid diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py new file mode 100644 index 0000000000..4899fb8830 --- /dev/null +++ b/monai/data/wsi_reader.py @@ -0,0 +1,420 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np + +from monai.config import DtypeLike, PathLike +from monai.data.image_reader import ImageReader, _stack_images +from monai.data.utils import is_supported_format +from monai.transforms.utility.array import EnsureChannelFirst +from monai.utils import ensure_tuple, optional_import, require_pkg + +CuImage, _ = optional_import("cucim", name="CuImage") + +__all__ = ["BaseWSIReader", "WSIReader", "CuCIMWSIReader"] + + +class BaseWSIReader(ImageReader): + """ + An abstract class that defines APIs to load patches from whole slide image files. + + Typical usage of a concrete implementation of this class is: + + .. code-block:: python + + image_reader = MyWSIReader() + wsi = image_reader.read(, **kwargs) + img_data, meta_data = image_reader.get_data(wsi) + + - The `read` call converts an image filename into whole slide image object, + - The `get_data` call fetches the image data, as well as meta data. + + The following methods needs to be implemented for any concrete implementation of this class: + + - `read` reads a whole slide image object from a given file + - `get_size` returns the size of the whole slide image of a given wsi object at a given level. + - `get_level_count` returns the number of levels in the whole slide image + - `get_patch` extracts and returns a patch image form the whole slide image + - `get_metadata` extracts and returns metadata for a whole slide image and a specific patch. + + + """ + + supported_suffixes: List[str] = [] + + def __init__(self, level: int, **kwargs): + super().__init__() + self.level = level + self.kwargs = kwargs + self.metadata: Dict[Any, Any] = {} + + @abstractmethod + def get_size(self, wsi, level: int) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_level_count(self, wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def get_data( + self, + wsi, + location: Tuple[int, int] = (0, 0), + size: Optional[Tuple[int, int]] = None, + level: Optional[int] = None, + dtype: DtypeLike = np.uint8, + mode: str = "RGB", + ) -> Tuple[np.ndarray, Dict]: + """ + Verifies inputs, extracts patches from WSI image and generates metadata, and return them. + + Args: + wsi: a whole slide image object loaded from a file or a list of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + Returns: + a tuples, where the first element is an image patch [CxHxW] or stack of patches, + and second element is a dictionary of metadata + """ + patch_list: List = [] + metadata = {} + # CuImage object is iterable, so ensure_tuple won't work on single object + if not isinstance(wsi, List): + wsi = [wsi] + for each_wsi in ensure_tuple(wsi): + # Verify magnification level + if level is None: + level = self.level + max_level = self.get_level_count(each_wsi) - 1 + if level > max_level: + raise ValueError(f"The maximum level of this image is {max_level} while level={level} is requested)!") + + # Verify location + if location is None: + location = (0, 0) + wsi_size = self.get_size(each_wsi, level) + if location[0] > wsi_size[0] or location[1] > wsi_size[1]: + raise ValueError(f"Location is outside of the image: location={location}, image size={wsi_size}") + + # Verify size + if size is None: + if location != (0, 0): + raise ValueError("Patch size should be defined to exctract patches.") + size = self.get_size(each_wsi, level) + else: + if size[0] <= 0 or size[1] <= 0: + raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}") + + # Extract a patch or the entire image + patch = self.get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) + + # check if the image has three dimensions (2D + color) + if patch.ndim != 3: + raise ValueError( + f"The image dimension should be 3 but has {patch.ndim}. " + "`WSIReader` is designed to work only with 2D images with color channel." + ) + + # Create a list of patches + patch_list.append(patch) + + # Set patch-related metadata + each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level) + metadata.update(each_meta) + + return _stack_images(patch_list, metadata), metadata + + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: + """ + Verify whether the specified file or files format is supported by WSI reader. + + The list of supported suffixes are read from `self.supported_suffixes`. + + Args: + filename: filename or a list of filenames to read. + + """ + return is_supported_format(filename, self.supported_suffixes) + + +class WSIReader(BaseWSIReader): + """ + Read whole slide images and extract patches using different backend libraries + + Args: + backend: the name of backend whole slide image reader library, the default is cuCIM. + level: the level at which patches are extracted. + kwargs: additional arguments to be passed to the backend library + + """ + + def __init__(self, backend="cucim", level: int = 0, **kwargs): + super().__init__(level, **kwargs) + self.backend = backend.lower() + # Any new backend can be added below + if self.backend == "cucim": + self.reader = CuCIMWSIReader(level=level, **kwargs) + else: + raise ValueError("The supported backends are: cucim") + self.supported_suffixes = self.reader.supported_suffixes + + def get_level_count(self, wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + return self.reader.get_level_count(wsi) + + def get_size(self, wsi, level) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + return self.reader.get_size(wsi, level) + + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + return self.reader.get_metadata(patch=patch, size=size, location=location, level=level) + + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + return self.reader.get_patch(wsi=wsi, location=location, size=size, level=level, dtype=dtype, mode=mode) + + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + """ + Read whole slide image objects from given file or list of files. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for the reader module (overrides `self.kwargs` for existing keys). + + Returns: + whole slide image object or list of such objects + + """ + return self.reader.read(data=data, **kwargs) + + +@require_pkg(pkg_name="cucim") +class CuCIMWSIReader(BaseWSIReader): + """ + Read whole slide images and extract patches without loading the whole slide image into the memory. + + Args: + level: the whole slide image level at which the image is extracted. (default=0) + This is overridden if the level argument is provided in `get_data`. + kwargs: additional args for `cucim.CuImage` module: + https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h + + """ + + supported_suffixes = ["tif", "tiff", "svs"] + + def __init__(self, level: int = 0, **kwargs): + super().__init__(level, **kwargs) + + @staticmethod + def get_level_count(wsi) -> int: + """ + Returns the number of levels in the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file + + """ + return wsi.resolutions["level_count"] # type: ignore + + @staticmethod + def get_size(wsi, level) -> Tuple[int, int]: + """ + Returns the size of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number where the size is calculated + + """ + return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) + + def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict: + """ + Extracts and returns metadata form the whole slide image. + + Args: + patch: extracted patch from whole slide image + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + + """ + metadata: Dict = { + "backend": "cucim", + "spatial_shape": np.asarray(patch.shape[1:]), + "original_channel_dim": 0, + "location": location, + "size": size, + "level": level, + } + return metadata + + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): + """ + Read whole slide image objects from given file or list of files. + + Args: + data: file name or a list of file names to read. + kwargs: additional args that overrides `self.kwargs` for existing keys. + For more details look at https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h + + Returns: + whole slide image object or list of such objects + + """ + wsi_list: List = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for filename in filenames: + wsi = CuImage(filename, **kwargs_) + wsi_list.append(wsi) + + return wsi_list if len(filenames) > 1 else wsi_list[0] + + def get_patch( + self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str + ) -> np.ndarray: + """ + Extracts and returns a patch image form the whole slide image. + + Args: + wsi: a whole slide image object loaded from a file or a lis of such objects + location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). + size: (height, width) tuple giving the patch size at the given level (`level`). + If None, it is set to the full image size at the given level. + level: the level number. Defaults to 0 + dtype: the data type of output image + mode: the output image mode, 'RGB' or 'RGBA' + + """ + # Extract a patch or the entire image + # (reverse the order of location and size to become WxH for cuCIM) + patch: np.ndarray = wsi.read_region(location=location[::-1], size=size[::-1], level=level) + + # Convert to numpy + patch = np.asarray(patch, dtype=dtype) + + # Make it channel first + patch = EnsureChannelFirst()(patch, {"original_channel_dim": -1}) # type: ignore + + # Check if the color channel is 3 (RGB) or 4 (RGBA) + if mode == "RGBA" and patch.shape[0] != 4: + raise ValueError( + f"The image is expected to have four color channels in '{mode}' mode but has {patch.shape[0]}." + ) + + if mode in "RGB": + if patch.shape[0] not in [3, 4]: + raise ValueError( + f"The image is expected to have three or four color channels in '{mode}' mode but has {patch.shape[0]}. " + ) + patch = patch[:3] + + return patch diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 6ee02143b8..7b288f6040 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.data import DataLoader, Dataset -from monai.data.image_reader import WSIReader +from monai.data.wsi_reader import WSIReader from monai.transforms import Compose, LoadImaged, ToTensord from monai.utils import first, optional_import from monai.utils.enums import PostFix @@ -57,29 +57,17 @@ ] TEST_CASE_3 = [ - FILE_PATH, - {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 2}, - np.array( + [FILE_PATH, FILE_PATH], + {"location": (0, 0), "size": (2, 1), "level": 2}, + np.concatenate( [ - [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], - [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], - ] + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), + ], + axis=0, ), ] -TEST_CASE_4 = [ - FILE_PATH, - {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 1}, - np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), -] - -TEST_CASE_5 = [ - FILE_PATH, - {"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)}, - np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]), -] - - TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW @@ -138,7 +126,7 @@ def test_read_whole_image(self, file_path, level, expected_shape): img = reader.get_data(img_obj)[0] self.assertTupleEqual(img.shape, expected_shape) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_read_region(self, file_path, patch_info, expected_img): kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} reader = WSIReader(self.backend, **kwargs) @@ -155,17 +143,22 @@ def test_read_region(self, file_path, patch_info, expected_img): self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) - @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) - def test_read_patches(self, file_path, patch_info, expected_img): - reader = WSIReader(self.backend) - with reader.read(file_path) as img_obj: - if self.backend == "tifffile": - with self.assertRaises(ValueError): - reader.get_data(img_obj, **patch_info)[0] - else: - img = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) + @parameterized.expand([TEST_CASE_3]) + def test_read_region_multi_wsi(self, file_path, patch_info, expected_img): + kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} + reader = WSIReader(self.backend, **kwargs) + img_obj = reader.read(file_path, **kwargs) + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + # Read twice to check multiple calls + img = reader.get_data(img_obj, **patch_info)[0] + img2 = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, img2.shape) + self.assertIsNone(assert_array_equal(img, img2)) + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) @skipUnless(has_tiff, "Requires tifffile.") @@ -221,19 +214,5 @@ def setUpClass(cls): cls.backend = "cucim" -@skipUnless(has_osl, "Requires OpenSlide") -class TestOpenSlide(WSIReaderTests.Tests): - @classmethod - def setUpClass(cls): - cls.backend = "openslide" - - -@skipUnless(has_tiff, "Requires TiffFile") -class TestTiffFile(WSIReaderTests.Tests): - @classmethod - def setUpClass(cls): - cls.backend = "tifffile" - - if __name__ == "__main__": unittest.main() From 14b189ad65837b6932c2c9236d44d1549f2036bc Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 19 Apr 2022 19:04:54 +0000 Subject: [PATCH 02/14] Implement Split transform Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/__init__.py | 1 + monai/transforms/spatial/array.py | 81 +++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 581e368ba0..3620e611f4 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -330,6 +330,7 @@ Rotate90, Spacing, SpatialResample, + Split, Zoom, ) from .spatial.dictionary import ( diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 41d048fc19..75cd19d86e 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -18,6 +18,7 @@ import numpy as np import torch +from numpy.lib.stride_tricks import as_strided from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor @@ -2460,3 +2461,83 @@ def __call__( if not self._do_transform: return img return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) + + +class Split(Transform): + """ + Split the image into patches based on the provided grid in 2D. + + Args: + grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) + size: a tuple or an integer that defines the output patch sizes. + If it's an integer, the value will be repeated for each dimension. + The default is None, where the patch size will be inferred from the grid shape. + + Note: This transform currently support only image with two spatial dimensions. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None + ): + # Grid size + self.grid = grid + + # Patch size + if size is None: + self.size = None + else: + self.size = ensure_tuple_rep(size, len(self.grid)) + + def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: + if self.grid == (1, 1) and self.size is None: + if isinstance(image, torch.Tensor): + return torch.stack([image]) + elif isinstance(image, np.ndarray): + return np.stack([image]) # type: ignore + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + + size, steps = self.get_params(image.shape[1:]) + patches: NdarrayOrTensor + if isinstance(image, torch.Tensor): + patches = ( + image.unfold(1, size[0], steps[0]) + .unfold(2, size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + .contiguous() + ) + elif isinstance(image, np.ndarray): + x_step, y_step = steps + c_stride, x_stride, y_stride = image.strides + n_channels = image.shape[0] + patches = as_strided( + image, + shape=(*self.grid, n_channels, size[0], size[1]), + strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), + writeable=False, + ) + # flatten the first two dimensions + patches = patches.reshape(np.prod(patches.shape[:2]), *patches.shape[2:]) + # make it a contiguous array + patches = np.ascontiguousarray(patches) + else: + raise ValueError(f"Input type [{type(image)}] is not supported.") + + return patches + + def get_params(self, image_size): + if self.size is None: + size = tuple(image_size[i] // self.grid[i] for i in range(2)) + else: + size = self.size + + steps = tuple( + (image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] + for i in range(2) + ) + + return size, steps + From 0ffd9dcead7cf771c7816bd90032d7f63d96c7b6 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 19 Apr 2022 19:07:26 +0000 Subject: [PATCH 03/14] Add unittests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_split_transform.py | 84 +++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/test_split_transform.py diff --git a/tests/test_split_transform.py b/tests/test_split_transform.py new file mode 100644 index 0000000000..f2229c9ac6 --- /dev/null +++ b/tests/test_split_transform.py @@ -0,0 +1,84 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import Split +from tests.utils import TEST_NDARRAYS, assert_allclose + +A11 = torch.randn(3, 2, 2) +A12 = torch.randn(3, 2, 2) +A21 = torch.randn(3, 2, 2) +A22 = torch.randn(3, 2, 2) + +A1 = torch.cat([A11, A12], 2) +A2 = torch.cat([A21, A22], 2) +A = torch.cat([A1, A2], 1) + +TEST_CASE_0 = [{"grid": (2, 2)}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"grid": (2, 1)}, A, torch.stack([A1, A2])] +TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])] +TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])] +TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])] +TEST_CASE_5 = [{"grid": 1, "size": 4}, A, torch.stack([A])] +TEST_CASE_6 = [{"grid": 2, "size": 2}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"grid": 1}, A, torch.stack([A])] +TEST_CASE_8 = [ + {"grid": (2, 2), "size": 2}, + torch.arange(12).reshape(1, 3, 4).to(torch.float32), + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + +TEST_CASE_MC_0 = [{"grid": (2, 2)}, [A, A], [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])]] +TEST_CASE_MC_1 = [{"grid": (2, 1)}, [A] * 5, [torch.stack([A1, A2])] * 5] +TEST_CASE_MC_2 = [{"grid": (1, 2)}, [A1, A2], [torch.stack([A11, A12]), torch.stack([A21, A22])]] + +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + + +class TestSplit(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, image, expected): + input_image = in_type(image) + splitter = Split(**input_parameters) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) + + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): + splitter = Split(**input_parameters) + for image, expected in zip(img_list, expected_list): + input_image = in_type(image) + output = splitter(input_image) + assert_allclose(output, expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() From a6c02095222325695918b56fb479a930126800a9 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 19 Apr 2022 19:18:55 +0000 Subject: [PATCH 04/14] Update formatting Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/__init__.py | 3 +++ monai/transforms/spatial/array.py | 8 ++---- monai/transforms/spatial/dictionary.py | 35 ++++++++++++++++++++++++++ tests/test_split_transform.py | 6 ++--- 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 3620e611f4..9e0fedbf33 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -391,6 +391,9 @@ SpatialResampled, SpatialResampleD, SpatialResampleDict, + Splitd, + SplitD, + SplitDict, Zoomd, ZoomD, ZoomDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 75cd19d86e..8aca978396 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2478,9 +2478,7 @@ class Split(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__( - self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None - ): + def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tuple[int, int]]] = None): # Grid size self.grid = grid @@ -2535,9 +2533,7 @@ def get_params(self, image_size): size = self.size steps = tuple( - (image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] - for i in range(2) + (image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] for i in range(2) ) return size, steps - diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d42a11fd2f..8fa04677d1 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -49,6 +49,7 @@ Rotate90, Spacing, SpatialResample, + Split, Zoom, ) from monai.transforms.transform import MapTransform, RandomizableTransform @@ -2149,6 +2150,39 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d + +class Splitd(MapTransform): + """ + Split the image into patches based on the provided grid in 2D. + + Args: + grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) + size: a tuple or an integer that defines the output patch sizes. + If it's an integer, the value will be repeated for each dimension. + The default is None, where the patch size will be inferred from the grid shape. + + Note: This transform currently support only image with two spatial dimensions. + """ + + backend = Split.backend + + def __init__( + self, + keys: KeysCollection, + grid: Tuple[int, int] = (2, 2), + size: Optional[Union[int, Tuple[int, int]]] = None, + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) + self.splitter = Split(grid=grid, size=size) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.splitter(d[key]) + return d + + SpatialResampleD = SpatialResampleDict = SpatialResampled ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd SpacingD = SpacingDict = Spacingd @@ -2169,3 +2203,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N RandRotateD = RandRotateDict = RandRotated ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd +SplitD = SplitDict = Splitd diff --git a/tests/test_split_transform.py b/tests/test_split_transform.py index f2229c9ac6..dcb9c6b2a5 100644 --- a/tests/test_split_transform.py +++ b/tests/test_split_transform.py @@ -31,9 +31,9 @@ TEST_CASE_2 = [{"grid": (1, 2)}, A1, torch.stack([A11, A12])] TEST_CASE_3 = [{"grid": (1, 2)}, A2, torch.stack([A21, A22])] TEST_CASE_4 = [{"grid": (1, 1), "size": (2, 2)}, A, torch.stack([A11])] -TEST_CASE_5 = [{"grid": 1, "size": 4}, A, torch.stack([A])] -TEST_CASE_6 = [{"grid": 2, "size": 2}, A, torch.stack([A11, A12, A21, A22])] -TEST_CASE_7 = [{"grid": 1}, A, torch.stack([A])] +TEST_CASE_5 = [{"grid": (1, 1), "size": 4}, A, torch.stack([A])] +TEST_CASE_6 = [{"grid": (2, 2), "size": 2}, A, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"grid": (1, 1)}, A, torch.stack([A])] TEST_CASE_8 = [ {"grid": (2, 2), "size": 2}, torch.arange(12).reshape(1, 3, 4).to(torch.float32), From 651c6bf00e23d44fa1dcd08589447b5f3154760c Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 19 Apr 2022 19:19:25 +0000 Subject: [PATCH 05/14] Implement SplitDict Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 8fa04677d1..96b4732136 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2150,7 +2150,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d - class Splitd(MapTransform): """ Split the image into patches based on the provided grid in 2D. From 7bb2335cb0e4f33f5f26b28a57af0d936670aec5 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 19 Apr 2022 19:19:45 +0000 Subject: [PATCH 06/14] Add unittests for SplitDict Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/test_split_transform_dict.py | 100 +++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/test_split_transform_dict.py diff --git a/tests/test_split_transform_dict.py b/tests/test_split_transform_dict.py new file mode 100644 index 0000000000..32fda4181c --- /dev/null +++ b/tests/test_split_transform_dict.py @@ -0,0 +1,100 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import SplitDict +from tests.utils import TEST_NDARRAYS, assert_allclose + +A11 = torch.randn(3, 2, 2) +A12 = torch.randn(3, 2, 2) +A21 = torch.randn(3, 2, 2) +A22 = torch.randn(3, 2, 2) + +A1 = torch.cat([A11, A12], 2) +A2 = torch.cat([A21, A22], 2) +A = torch.cat([A1, A2], 1) + +TEST_CASE_0 = [{"keys": "image", "grid": (2, 2)}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_1 = [{"keys": "image", "grid": (2, 1)}, {"image": A}, torch.stack([A1, A2])] +TEST_CASE_2 = [{"keys": "image", "grid": (1, 2)}, {"image": A1}, torch.stack([A11, A12])] +TEST_CASE_3 = [{"keys": "image", "grid": (1, 2)}, {"image": A2}, torch.stack([A21, A22])] +TEST_CASE_4 = [{"keys": "image", "grid": (1, 1), "size": (2, 2)}, {"image": A}, torch.stack([A11])] +TEST_CASE_5 = [{"keys": "image", "grid": (1, 1), "size": 4}, {"image": A}, torch.stack([A])] +TEST_CASE_6 = [{"keys": "image", "grid": (2, 2), "size": 2}, {"image": A}, torch.stack([A11, A12, A21, A22])] +TEST_CASE_7 = [{"keys": "image", "grid": (1, 1)}, {"image": A}, torch.stack([A])] +TEST_CASE_8 = [ + {"keys": "image", "grid": (2, 2), "size": 2}, + {"image": torch.arange(12).reshape(1, 3, 4).to(torch.float32)}, + torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32), +] + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + +TEST_CASE_MC_0 = [ + {"keys": "image", "grid": (2, 2)}, + [{"image": A}, {"image": A}], + [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], +] +TEST_CASE_MC_1 = [ + {"keys": "image", "grid": (2, 1)}, + [{"image": A}, {"image": A}, {"image": A}], + [torch.stack([A1, A2])] * 3, +] +TEST_CASE_MC_2 = [ + {"keys": "image", "grid": (1, 2)}, + [{"image": A1}, {"image": A2}], + [torch.stack([A11, A12]), torch.stack([A21, A22])], +] + +TEST_MULTIPLE = [] +for p in TEST_NDARRAYS: + TEST_MULTIPLE.append([p, *TEST_CASE_MC_0]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_1]) + TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) + + +class TestSplitDict(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + splitter = SplitDict(**input_parameters) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) + + @parameterized.expand(TEST_MULTIPLE) + def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): + splitter = SplitDict(**input_parameters) + for img_dict, expected in zip(img_list, expected_list): + input_dict = {} + for k, v in img_dict.items(): + input_dict[k] = in_type(v) + output = splitter(input_dict)[input_parameters["keys"]] + assert_allclose(output, expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() From dfb2ecd28cc32cf8cdba7154acbd533cb1cc2a13 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 20 Apr 2022 18:03:12 +0000 Subject: [PATCH 07/14] Add docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/transforms.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 676e0274fe..6d56e78257 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -737,6 +737,15 @@ Spatial :members: :special-members: __call__ +`Split` +"""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoom.png + :alt: example of Split +.. autoclass:: Split + :members: + :special-members: __call__ + + Smooth Field ^^^^^^^^^^^^ @@ -1506,6 +1515,15 @@ Spatial (Dict) :members: :special-members: __call__ +`Splitd` +""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoomd.png + :alt: example of Splitd +.. autoclass:: Splitd + :members: + :special-members: __call__ + + `RandRotate90d` """"""""""""""" .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandRotate90d.png From 072203d53547d893de78f4a8f6449b20c1c8f89f Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 22 Apr 2022 15:29:23 +0000 Subject: [PATCH 08/14] Remove images from docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/transforms.rst | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 6d56e78257..6eb5fbf4dd 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -737,11 +737,9 @@ Spatial :members: :special-members: __call__ -`Split` +`GridSplit` """""""" -.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoom.png - :alt: example of Split -.. autoclass:: Split +.. autoclass:: GridSplit :members: :special-members: __call__ @@ -1515,11 +1513,9 @@ Spatial (Dict) :members: :special-members: __call__ -`Splitd` +`GridSplitd` """"""""" -.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Zoomd.png - :alt: example of Splitd -.. autoclass:: Splitd +.. autoclass:: GridSplitd :members: :special-members: __call__ From f712f8a08ed67802dd6dec6a5cb8677cc9115ffb Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 22 Apr 2022 15:46:29 +0000 Subject: [PATCH 09/14] Address all comments Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/__init__.py | 8 ++--- monai/transforms/spatial/array.py | 31 +++++++++++-------- monai/transforms/spatial/dictionary.py | 15 ++++++--- ..._split_transform.py => test_grid_split.py} | 8 ++--- ..._transform_dict.py => test_grid_splitd.py} | 8 ++--- 5 files changed, 40 insertions(+), 30 deletions(-) rename tests/{test_split_transform.py => test_grid_split.py} (94%) rename tests/{test_split_transform_dict.py => test_grid_splitd.py} (95%) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9e0fedbf33..c2385499b3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -311,6 +311,7 @@ AffineGrid, Flip, GridDistortion, + GridSplit, Orientation, Rand2DElastic, Rand3DElastic, @@ -330,7 +331,6 @@ Rotate90, Spacing, SpatialResample, - Split, Zoom, ) from .spatial.dictionary import ( @@ -343,6 +343,9 @@ GridDistortiond, GridDistortionD, GridDistortionDict, + GridSplitd, + GridSplitD, + GridSplitDict, Orientationd, OrientationD, OrientationDict, @@ -391,9 +394,6 @@ SpatialResampled, SpatialResampleD, SpatialResampleDict, - Splitd, - SplitD, - SplitDict, Zoomd, ZoomD, ZoomDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d100a31a1c..da62e3ffc2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -66,7 +66,7 @@ "Orientation", "Flip", "GridDistortion", - "Resize", + "GridSplit" "Resize", "Rotate", "Zoom", "Rotate90", @@ -2465,7 +2465,7 @@ def __call__( return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode) -class Split(Transform): +class GridSplit(Transform): """ Split the image into patches based on the provided grid in 2D. @@ -2485,10 +2485,7 @@ def __init__(self, grid: Tuple[int, int] = (2, 2), size: Optional[Union[int, Tup self.grid = grid # Patch size - if size is None: - self.size = None - else: - self.size = ensure_tuple_rep(size, len(self.grid)) + self.size = None if size is None else ensure_tuple_rep(size, len(self.grid)) def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: if self.grid == (1, 1) and self.size is None: @@ -2499,7 +2496,7 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: else: raise ValueError(f"Input type [{type(image)}] is not supported.") - size, steps = self.get_params(image.shape[1:]) + size, steps = self._get_params(image.shape[1:]) patches: NdarrayOrTensor if isinstance(image, torch.Tensor): patches = ( @@ -2528,14 +2525,22 @@ def __call__(self, image: NdarrayOrTensor) -> NdarrayOrTensor: return patches - def get_params(self, image_size): - if self.size is None: - size = tuple(image_size[i] // self.grid[i] for i in range(2)) + def _get_params(self, image_size: Union[Sequence[int], np.ndarray]): + """ + Calculate the size and step required for splitting the image + Args: + The size of the input image + """ + if self.size is not None: + # Set the requested size + split_size = self.size else: - size = self.size + # infer each sub-image size from the image size and the grid + split_size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid))) steps = tuple( - (image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] for i in range(2) + (image_size[i] - split_size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i] + for i in range(len(self.grid)) ) - return size, steps + return split_size, steps diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 96b4732136..47fe05700e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -34,6 +34,7 @@ AffineGrid, Flip, GridDistortion, + GridSplit, Orientation, Rand2DElastic, Rand3DElastic, @@ -49,7 +50,6 @@ Rotate90, Spacing, SpatialResample, - Split, Zoom, ) from monai.transforms.transform import MapTransform, RandomizableTransform @@ -130,6 +130,9 @@ "ZoomDict", "RandZoomD", "RandZoomDict", + "GridSplitd", + "GridSplitD", + "GridSplitDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] @@ -2150,20 +2153,22 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d -class Splitd(MapTransform): +class GridSplitd(MapTransform): """ Split the image into patches based on the provided grid in 2D. Args: + keys: keys of the corresponding items to be transformed. grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2) size: a tuple or an integer that defines the output patch sizes. If it's an integer, the value will be repeated for each dimension. The default is None, where the patch size will be inferred from the grid shape. + allow_missing_keys: don't raise exception if key is missing. Note: This transform currently support only image with two spatial dimensions. """ - backend = Split.backend + backend = GridSplit.backend def __init__( self, @@ -2173,7 +2178,7 @@ def __init__( allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) - self.splitter = Split(grid=grid, size=size) + self.splitter = GridSplit(grid=grid, size=size) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -2202,4 +2207,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N RandRotateD = RandRotateDict = RandRotated ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd -SplitD = SplitDict = Splitd +GridSplitD = GridSplitDict = GridSplitd diff --git a/tests/test_split_transform.py b/tests/test_grid_split.py similarity index 94% rename from tests/test_split_transform.py rename to tests/test_grid_split.py index dcb9c6b2a5..6f0525029d 100644 --- a/tests/test_split_transform.py +++ b/tests/test_grid_split.py @@ -14,7 +14,7 @@ import torch from parameterized import parameterized -from monai.transforms import Split +from monai.transforms import GridSplit from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) @@ -63,17 +63,17 @@ TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) -class TestSplit(unittest.TestCase): +class TestGridSplit(unittest.TestCase): @parameterized.expand(TEST_SINGLE) def test_split_patch_single_call(self, in_type, input_parameters, image, expected): input_image = in_type(image) - splitter = Split(**input_parameters) + splitter = GridSplit(**input_parameters) output = splitter(input_image) assert_allclose(output, expected, type_test=False) @parameterized.expand(TEST_MULTIPLE) def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): - splitter = Split(**input_parameters) + splitter = GridSplit(**input_parameters) for image, expected in zip(img_list, expected_list): input_image = in_type(image) output = splitter(input_image) diff --git a/tests/test_split_transform_dict.py b/tests/test_grid_splitd.py similarity index 95% rename from tests/test_split_transform_dict.py rename to tests/test_grid_splitd.py index 32fda4181c..f325a16946 100644 --- a/tests/test_split_transform_dict.py +++ b/tests/test_grid_splitd.py @@ -14,7 +14,7 @@ import torch from parameterized import parameterized -from monai.transforms import SplitDict +from monai.transforms import GridSplitd from tests.utils import TEST_NDARRAYS, assert_allclose A11 = torch.randn(3, 2, 2) @@ -75,19 +75,19 @@ TEST_MULTIPLE.append([p, *TEST_CASE_MC_2]) -class TestSplitDict(unittest.TestCase): +class TestGridSplitd(unittest.TestCase): @parameterized.expand(TEST_SINGLE) def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected): input_dict = {} for k, v in img_dict.items(): input_dict[k] = in_type(v) - splitter = SplitDict(**input_parameters) + splitter = GridSplitd(**input_parameters) output = splitter(input_dict)[input_parameters["keys"]] assert_allclose(output, expected, type_test=False) @parameterized.expand(TEST_MULTIPLE) def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list): - splitter = SplitDict(**input_parameters) + splitter = GridSplitd(**input_parameters) for img_dict, expected in zip(img_list, expected_list): input_dict = {} for k, v in img_dict.items(): From bf55c0518ce75d2b159c7a860f519ba5bb23cf19 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 22 Apr 2022 16:01:27 +0000 Subject: [PATCH 10/14] Add example and size check Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index da62e3ffc2..ba4f252287 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2475,6 +2475,11 @@ class GridSplit(Transform): If it's an integer, the value will be repeated for each dimension. The default is None, where the patch size will be inferred from the grid shape. + Example: + Given a image (torch.Tensor or numpy.ndarray) with size of (3, 10, 10) and a grid of (2, 2), + it will return a Tensor or array with the side of (4, 3, 5, 5) + Here, if the size is provided, the returned shape will be (4, 3, size, size) + Note: This transform currently support only image with two spatial dimensions. """ @@ -2532,7 +2537,9 @@ def _get_params(self, image_size: Union[Sequence[int], np.ndarray]): The size of the input image """ if self.size is not None: - # Set the requested size + # Set the split size to the given default size + if self.size > image_size: + raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})") split_size = self.size else: # infer each sub-image size from the image size and the grid From 630eb2a4eb25fd35f42acee16b6c6f9efdbbd5ce Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 22 Apr 2022 16:09:54 +0000 Subject: [PATCH 11/14] Update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/transforms.rst | 4 ++-- monai/transforms/spatial/array.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 6eb5fbf4dd..a93c48984c 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -738,7 +738,7 @@ Spatial :special-members: __call__ `GridSplit` -"""""""" +""""""""""" .. autoclass:: GridSplit :members: :special-members: __call__ @@ -1514,7 +1514,7 @@ Spatial (Dict) :special-members: __call__ `GridSplitd` -""""""""" +"""""""""""" .. autoclass:: GridSplitd :members: :special-members: __call__ diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ba4f252287..0b22a9ed60 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2476,9 +2476,9 @@ class GridSplit(Transform): The default is None, where the patch size will be inferred from the grid shape. Example: - Given a image (torch.Tensor or numpy.ndarray) with size of (3, 10, 10) and a grid of (2, 2), - it will return a Tensor or array with the side of (4, 3, 5, 5) - Here, if the size is provided, the returned shape will be (4, 3, size, size) + Given an image (torch.Tensor or numpy.ndarray) with size of (3, 10, 10) and a grid of (2, 2), + it will return a Tensor or array with the size of (4, 3, 5, 5). + Here, if the `size` is provided, the returned shape will be (4, 3, size, size) Note: This transform currently support only image with two spatial dimensions. """ From cc97cc1a565c5490cf80025cc6d0a173c3530eb7 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 22 Apr 2022 16:13:33 +0000 Subject: [PATCH 12/14] Revert references to new wsireader Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/apps/pathology/data/datasets.py | 2 +- monai/apps/pathology/metrics/lesion_froc.py | 2 +- tests/test_wsireader.py | 71 +++++++++++++-------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index 756223a784..71f3214ea4 100644 --- a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -16,7 +16,7 @@ import numpy as np from monai.data import Dataset, SmartCacheDataset -from monai.data.wsi_reader import WSIReader +from monai.data.image_reader import WSIReader from monai.utils import ensure_tuple_rep __all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset", "MaskedInferenceWSIDataset"] diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py index e48f2128fe..6073bd0cda 100644 --- a/monai/apps/pathology/metrics/lesion_froc.py +++ b/monai/apps/pathology/metrics/lesion_froc.py @@ -14,7 +14,7 @@ import numpy as np from monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask -from monai.data.wsi_reader import WSIReader +from monai.data.image_reader import WSIReader from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score from monai.utils import min_version, optional_import diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 7b288f6040..6ee02143b8 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.data import DataLoader, Dataset -from monai.data.wsi_reader import WSIReader +from monai.data.image_reader import WSIReader from monai.transforms import Compose, LoadImaged, ToTensord from monai.utils import first, optional_import from monai.utils.enums import PostFix @@ -57,17 +57,29 @@ ] TEST_CASE_3 = [ - [FILE_PATH, FILE_PATH], - {"location": (0, 0), "size": (2, 1), "level": 2}, - np.concatenate( + FILE_PATH, + {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 2}, + np.array( [ - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), - np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), - ], - axis=0, + [[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]], + [[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]], + ] ), ] +TEST_CASE_4 = [ + FILE_PATH, + {"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 1}, + np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), +] + +TEST_CASE_5 = [ + FILE_PATH, + {"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)}, + np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]), +] + + TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW @@ -126,7 +138,7 @@ def test_read_whole_image(self, file_path, level, expected_shape): img = reader.get_data(img_obj)[0] self.assertTupleEqual(img.shape, expected_shape) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5]) def test_read_region(self, file_path, patch_info, expected_img): kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} reader = WSIReader(self.backend, **kwargs) @@ -143,22 +155,17 @@ def test_read_region(self, file_path, patch_info, expected_img): self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) - @parameterized.expand([TEST_CASE_3]) - def test_read_region_multi_wsi(self, file_path, patch_info, expected_img): - kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {} - reader = WSIReader(self.backend, **kwargs) - img_obj = reader.read(file_path, **kwargs) - if self.backend == "tifffile": - with self.assertRaises(ValueError): - reader.get_data(img_obj, **patch_info)[0] - else: - # Read twice to check multiple calls - img = reader.get_data(img_obj, **patch_info)[0] - img2 = reader.get_data(img_obj, **patch_info)[0] - self.assertTupleEqual(img.shape, img2.shape) - self.assertIsNone(assert_array_equal(img, img2)) - self.assertTupleEqual(img.shape, expected_img.shape) - self.assertIsNone(assert_array_equal(img, expected_img)) + @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) + def test_read_patches(self, file_path, patch_info, expected_img): + reader = WSIReader(self.backend) + with reader.read(file_path) as img_obj: + if self.backend == "tifffile": + with self.assertRaises(ValueError): + reader.get_data(img_obj, **patch_info)[0] + else: + img = reader.get_data(img_obj, **patch_info)[0] + self.assertTupleEqual(img.shape, expected_img.shape) + self.assertIsNone(assert_array_equal(img, expected_img)) @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) @skipUnless(has_tiff, "Requires tifffile.") @@ -214,5 +221,19 @@ def setUpClass(cls): cls.backend = "cucim" +@skipUnless(has_osl, "Requires OpenSlide") +class TestOpenSlide(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "openslide" + + +@skipUnless(has_tiff, "Requires TiffFile") +class TestTiffFile(WSIReaderTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "tifffile" + + if __name__ == "__main__": unittest.main() From ab23363aa38b3c91cbb2f08030b160fe64326353 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 22 Apr 2022 16:15:33 +0000 Subject: [PATCH 13/14] Add missing comma Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0b22a9ed60..f2226c3107 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -66,7 +66,8 @@ "Orientation", "Flip", "GridDistortion", - "GridSplit" "Resize", + "GridSplit", + "Resize", "Rotate", "Zoom", "Rotate90", From 1b68e48751763728f4a5a8304a3be7ea6164b35f Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 22 Apr 2022 16:45:31 +0000 Subject: [PATCH 14/14] Fix size Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f2226c3107..6b67762b95 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -2539,7 +2539,7 @@ def _get_params(self, image_size: Union[Sequence[int], np.ndarray]): """ if self.size is not None: # Set the split size to the given default size - if self.size > image_size: + if any(self.size[i] > image_size[i] for i in range(len(self.grid))): raise ValueError("The image size ({image_size})is smaller than the requested split size ({self.size})") split_size = self.size else: