diff --git a/docs/source/data.rst b/docs/source/data.rst index aeeba539c5..7f59e587ec 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -312,6 +312,11 @@ PatchWSIDataset .. autoclass:: monai.data.PatchWSIDataset :members: +SlidingPatchWSIDataset +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.data.SlidingPatchWSIDataset + :members: + Bounding box ------------ .. automodule:: monai.data.box_utils diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 8ab4742a75..40ee3cfc29 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -83,6 +83,7 @@ get_valid_patch_size, is_supported_format, iter_patch, + iter_patch_position, iter_patch_slices, json_hashing, list_data_collate, @@ -103,7 +104,7 @@ worker_init_fn, zoom_affine, ) -from .wsi_datasets import PatchWSIDataset +from .wsi_datasets import PatchWSIDataset, SlidingPatchWSIDataset from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, WSIReader with contextlib.suppress(BaseException): diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 33497b5a68..2e389d9e0b 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -73,6 +73,7 @@ def __call__(self, array: np.ndarray): array, patch_size=self.patch_size, # type: ignore start_pos=self.start_pos, + overlap=0.0, copy_back=False, mode=self.mode, **self.pad_opts, diff --git a/monai/data/utils.py b/monai/data/utils.py index ca91006817..e56d8b86e9 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -67,6 +67,7 @@ "get_valid_patch_size", "is_supported_format", "iter_patch", + "iter_patch_position", "iter_patch_slices", "json_hashing", "list_data_collate", @@ -123,32 +124,36 @@ def get_random_patch( def iter_patch_slices( - dims: Sequence[int], patch_size: Union[Sequence[int], int], start_pos: Sequence[int] = () + image_size: Sequence[int], + patch_size: Union[Sequence[int], int], + start_pos: Sequence[int] = (), + overlap: Union[Sequence[float], float] = 0.0, + padded: bool = True, ) -> Generator[Tuple[slice, ...], None, None]: """ - Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `dims`. The - iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each - patch is chosen in a contiguous grid using a first dimension as least significant ordering. + Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `image_size`. + The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each + patch is chosen in a contiguous grid using a row-major ordering. Args: - dims: dimensions of array to iterate over + image_size: dimensions of array to iterate over + patch_size: size of patches to generate slices for, 0 or None selects whole dimension start_pos: starting position in the array, default is 0 for each dimension + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. + padded: whether the image is padded so the patches can go beyond the borders. Defaults to True. 
Yields: Tuples of slice objects defining each patch """ - # ensure patchSize and startPos are the right length - ndim = len(dims) - patch_size_ = get_valid_patch_size(dims, patch_size) - start_pos = ensure_tuple_size(start_pos, ndim) + # ensure patch_size has the right length + patch_size_ = get_valid_patch_size(image_size, patch_size) - # collect the ranges to step over each dimension - ranges = tuple(starmap(range, zip(start_pos, dims, patch_size_))) - - # choose patches by applying product to the ranges - for position in product(*ranges): + # create slices based on start position of each patch + for position in iter_patch_position( + image_size=image_size, patch_size=patch_size_, start_pos=start_pos, overlap=overlap, padded=padded + ): yield tuple(slice(s, s + p) for s, p in zip(position, patch_size_)) @@ -192,10 +197,54 @@ def dense_patch_slices( return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] +def iter_patch_position( + image_size: Sequence[int], + patch_size: Union[Sequence[int], int], + start_pos: Sequence[int] = (), + overlap: Union[Sequence[float], float] = 0.0, + padded: bool = False, +): + """ + Yield successive tuples of upper left corner of patches of size `patch_size` from an array of dimensions `image_size`. + The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each + patch is chosen in a contiguous grid using a row-major ordering. + + Args: + image_size: dimensions of array to iterate over + patch_size: size of patches to generate slices for, 0 or None selects whole dimension + start_pos: starting position in the array, default is 0 for each dimension + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. + padded: whether the image is padded so the patches can go beyond the borders. Defaults to False. + + Yields: + Tuples of positions defining the upper left corner of each patch + """ + + # ensure patch_size and start_pos are the right length + ndim = len(image_size) + patch_size_ = get_valid_patch_size(image_size, patch_size) + start_pos = ensure_tuple_size(start_pos, ndim) + overlap = ensure_tuple_rep(overlap, ndim) + + # calculate steps, which depend on the amount of overlap + steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size_, overlap)) + + # calculate the last starting location (depending on the padding) + end_pos = image_size if padded else tuple(s - round(p) + 1 for s, p in zip(image_size, patch_size_)) + + # collect the ranges to step over each dimension + ranges = starmap(range, zip(start_pos, end_pos, steps)) + + # choose patches by applying product to the ranges + return product(*ranges) + + def iter_patch( arr: np.ndarray, patch_size: Union[Sequence[int], int] = 0, start_pos: Sequence[int] = (), + overlap: Union[Sequence[float], float] = 0.0, copy_back: bool = True, mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP, **pad_opts: Dict, @@ -209,6 +258,8 @@ def iter_patch( arr: array to iterate over patch_size: size of patches to generate slices for, 0 or None selects whole dimension start_pos: starting position in the array, default is 0 for each dimension + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. 
copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} @@ -243,7 +294,7 @@ def iter_patch( # patches which are only in the padded regions iter_size = tuple(s + p for s, p in zip(arr.shape, patch_size_)) - for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded): + for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded, overlap): # compensate original image padding coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, patch_size_)) yield arrpad[slices], np.asarray(coords_no_pad) # data and coords (in numpy; works with torch loader) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 665cbd196c..6fe5435d57 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -15,11 +15,12 @@ import numpy as np from monai.data import Dataset +from monai.data.utils import iter_patch_position from monai.data.wsi_reader import BaseWSIReader, WSIReader -from monai.transforms import apply_transform +from monai.transforms import Randomizable, apply_transform from monai.utils import ensure_tuple_rep -__all__ = ["PatchWSIDataset"] +__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset"] class PatchWSIDataset(Dataset): @@ -137,3 +138,130 @@ def _transform(self, index: int): # Apply transforms and output output = {"image": image, "label": label, "metadata": metadata} return apply_transform(self.transform, output) if self.transform else output + + +class SlidingPatchWSIDataset(Randomizable, PatchWSIDataset): + """ + This dataset extracts patches from whole slide images in a sliding-window manner (without loading the whole image into memory). + It provides each patch together with metadata describing its location within the slide. + + Args: + data: the list of input samples including the image path and, optionally, per-sample `size` and `level` (see the note below for more details). + size: the size of patch to be extracted from the whole slide image. + level: the level at which the patches are extracted (defaults to 0). + offset: the offset at which to start extracting patches (the starting position of the upper-left patch). + offset_limits: if offset is set to "random", a tuple of integers defining the lower and upper limit of the + random offset for all dimensions, or a tuple of tuples that defines the limits for each dimension. + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. + transform: transforms to be executed on input data. + reader: the module to be used for loading whole slide images. Defaults to cuCIM. If `reader` is + + - a string, it defines the backend of `monai.data.WSIReader`. + - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader, + - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader. + + seed: random seed to randomly generate offsets. Defaults to 0. + kwargs: additional arguments to pass to `WSIReader` or the provided whole slide reader class + + Note: + The input data has the following form as an example: + + .. 
code-block:: python + + [ + {"image": "path/to/image1.tiff"}, + {"image": "path/to/image2.tiff", "size": [20, 20], "level": 2} + ] + + """ + + def __init__( + self, + data: Sequence, + size: Optional[Union[int, Tuple[int, int]]] = None, + level: Optional[int] = None, + overlap: Union[Tuple[float, float], float] = 0.0, + offset: Union[Tuple[int, int], int, str] = (0, 0), + offset_limits: Optional[Union[Tuple[Tuple[int, int], Tuple[int, int]], Tuple[int, int]]] = None, + transform: Optional[Callable] = None, + reader="cuCIM", + seed: int = 0, + **kwargs, + ): + super().__init__(data=data, size=size, level=level, transform=transform, reader=reader, **kwargs) + self.overlap = overlap + self.set_random_state(seed) + # Set the offset config + self.random_offset = False + if isinstance(offset, str): + if offset == "random": + self.random_offset = True + self.offset_limits: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] + if offset_limits is None: + self.offset_limits = None + elif isinstance(offset_limits, tuple): + if isinstance(offset_limits[0], int): + self.offset_limits = (offset_limits, offset_limits) + elif isinstance(offset_limits[0], tuple): + self.offset_limits = offset_limits + else: + raise ValueError( + "The offset limits should be either a tuple of integers or a tuple of tuples of integers." + ) + else: + raise ValueError("The offset limits should be a tuple.") + else: + raise ValueError( + f'Invalid string for offset "{offset}". It should be either "random" as a string, ' + "an integer, or a tuple of integers defining the offset." + ) + else: + self.offset = ensure_tuple_rep(offset, 2) + + # Create a single sample for each patch (in a sliding-window manner) + self.data = [] + for sample in data: + sliding_samples = self._evaluate_patch_coordinates(sample) + self.data.extend(sliding_samples) + + def _get_offset(self, sample): + if self.random_offset: + if self.offset_limits is None: + offset_limits = tuple((-s, s) for s in self._get_size(sample)) + else: + offset_limits = self.offset_limits + return tuple(self.R.randint(low, high) for low, high in offset_limits) + return self.offset + + def _evaluate_patch_coordinates(self, sample): + """Define the location of each patch based on the sliding-window approach""" + patch_size = self._get_size(sample) + level = self._get_level(sample) + start_pos = self._get_offset(sample) + + wsi_obj = self._get_wsi_object(sample) + wsi_size = self.wsi_reader.get_size(wsi_obj, 0) + downsample = self.wsi_reader.get_downsample_ratio(wsi_obj, level) + patch_size_ = tuple(p * downsample for p in patch_size) # patch size at level 0 + locations = list( + iter_patch_position( + image_size=wsi_size, patch_size=patch_size_, start_pos=start_pos, overlap=self.overlap, padded=False + ) + ) + sample["size"] = patch_size + sample["level"] = level + n_patches = len(locations) + return [{**sample, "location": loc, "num_patches": n_patches} for loc in locations] + + def _get_location(self, sample: Dict): + return sample["location"] + + def _transform(self, index: int): + # Get a single entry of data + sample: Dict = self.data[index] + # Extract patch image and associated metadata + image, metadata = self._get_data(sample) + # Put all patch information together and apply transforms + patch = {"image": image, "metadata": metadata} + return apply_transform(self.transform, patch) if self.transform else patch diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index b29ac3848f..fdf7de3d63 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -85,6 +85,18 @@ def 
get_level_count(self, wsi) -> int: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod + def get_downsample_ratio(self, wsi, level: int) -> float: + """ + Returns the down-sampling ratio of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number at which the downsample ratio is calculated + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod def get_file_path(self, wsi) -> str: """Return the file path for the WSI object""" @@ -290,6 +302,17 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]: """ return self.reader.get_size(wsi, level) + def get_downsample_ratio(self, wsi, level: int) -> float: + """ + Returns the down-sampling ratio of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number at which the downsample ratio is calculated + + """ + return self.reader.get_downsample_ratio(wsi, level) + def get_file_path(self, wsi) -> str: """Return the file path for the WSI object""" return self.reader.get_file_path(wsi) @@ -369,6 +392,18 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0]) + @staticmethod + def get_downsample_ratio(wsi, level: int) -> float: + """ + Returns the down-sampling ratio of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number at which the downsample ratio is calculated + + """ + return wsi.resolutions["level_downsamples"][level] # type: ignore + def get_file_path(self, wsi) -> str: """Return the file path for the WSI object""" return str(abspath(wsi.path)) @@ -475,6 +510,18 @@ def get_size(wsi, level: int) -> Tuple[int, int]: """ return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0]) + @staticmethod + def get_downsample_ratio(wsi, level: int) -> float: + """ + Returns the down-sampling ratio of the whole slide image at a given level. + + Args: + wsi: a whole slide image object loaded from a file + level: the level number at which the downsample ratio is calculated + + """ + return wsi.level_downsamples[level] # type: ignore + def get_file_path(self, wsi) -> str: """Return the file path for the WSI object""" return str(abspath(wsi._filename)) diff --git a/tests/test_sliding_patch_wsi_dataset.py b/tests/test_sliding_patch_wsi_dataset.py new file mode 100644 index 0000000000..1eaa0292c5 --- /dev/null +++ b/tests/test_sliding_patch_wsi_dataset.py @@ -0,0 +1,258 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest +from unittest import skipUnless + +import numpy as np +from parameterized import parameterized + +from monai.data import SlidingPatchWSIDataset +from monai.utils import optional_import, set_determinism +from tests.utils import download_url_or_skip_test, testing_data_config + +set_determinism(0) + +cucim, has_cucim = optional_import("cucim") +has_cucim = has_cucim and hasattr(cucim, "CuImage") +openslide, has_osl = optional_import("openslide") +imwrite, has_tiff = optional_import("tifffile", name="imwrite") +_, has_codec = optional_import("imagecodecs") +has_tiff = has_tiff and has_codec + + +FILE_KEY = "wsi_img" +FILE_URL = testing_data_config("images", FILE_KEY, "url") +base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff" +FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension) + +FILE_PATH_SMALL_0 = os.path.join(os.path.dirname(__file__), "testing_data", "temp_wsi_inference_0.tiff") +FILE_PATH_SMALL_1 = os.path.join(os.path.dirname(__file__), "testing_data", "temp_wsi_inference_1.tiff") +ARRAY_SMALL_0 = np.random.randint(low=0, high=255, size=(3, 4, 4), dtype=np.uint8) +ARRAY_SMALL_1 = np.random.randint(low=0, high=255, size=(3, 5, 5), dtype=np.uint8) + +TEST_CASE_SMALL_0 = [ + {"data": [{"image": FILE_PATH_SMALL_0, "level": 0}], "size": (2, 2)}, + [ + {"image": ARRAY_SMALL_0[:, :2, :2]}, + {"image": ARRAY_SMALL_0[:, :2, 2:]}, + {"image": ARRAY_SMALL_0[:, 2:, :2]}, + {"image": ARRAY_SMALL_0[:, 2:, 2:]}, + ], +] + +TEST_CASE_SMALL_1 = [ + {"data": [{"image": FILE_PATH_SMALL_0, "level": 0, "size": (2, 2)}]}, + [ + {"image": ARRAY_SMALL_0[:, :2, :2]}, + {"image": ARRAY_SMALL_0[:, :2, 2:]}, + {"image": ARRAY_SMALL_0[:, 2:, :2]}, + {"image": ARRAY_SMALL_0[:, 2:, 2:]}, + ], +] + +TEST_CASE_SMALL_2 = [ + {"data": [{"image": FILE_PATH_SMALL_0, "level": 0}], "size": (2, 2), "overlap": 0.5}, + [ + {"image": ARRAY_SMALL_0[:, 0:2, 0:2]}, + {"image": ARRAY_SMALL_0[:, 0:2, 1:3]}, + {"image": ARRAY_SMALL_0[:, 0:2, 2:4]}, + {"image": ARRAY_SMALL_0[:, 1:3, 0:2]}, + {"image": ARRAY_SMALL_0[:, 1:3, 1:3]}, + {"image": ARRAY_SMALL_0[:, 1:3, 2:4]}, + {"image": ARRAY_SMALL_0[:, 2:4, 0:2]}, + {"image": ARRAY_SMALL_0[:, 2:4, 1:3]}, + {"image": ARRAY_SMALL_0[:, 2:4, 2:4]}, + ], +] + +TEST_CASE_SMALL_3 = [ + {"data": [{"image": FILE_PATH_SMALL_0, "level": 0}], "size": (3, 3), "overlap": 2.0 / 3.0}, + [ + {"image": ARRAY_SMALL_0[:, :3, :3]}, + {"image": ARRAY_SMALL_0[:, :3, 1:]}, + {"image": ARRAY_SMALL_0[:, 1:, :3]}, + {"image": ARRAY_SMALL_0[:, 1:, 1:]}, + ], +] + +TEST_CASE_SMALL_4 = [ + {"data": [{"image": FILE_PATH_SMALL_0, "level": 0}, {"image": FILE_PATH_SMALL_1, "level": 0}], "size": (2, 2)}, + [ + {"image": ARRAY_SMALL_0[:, 0:2, 0:2]}, + {"image": ARRAY_SMALL_0[:, 0:2, 2:4]}, + {"image": ARRAY_SMALL_0[:, 2:4, 0:2]}, + {"image": ARRAY_SMALL_0[:, 2:4, 2:4]}, + {"image": ARRAY_SMALL_1[:, 0:2, 0:2]}, + {"image": ARRAY_SMALL_1[:, 0:2, 2:4]}, + {"image": ARRAY_SMALL_1[:, 2:4, 0:2]}, + {"image": ARRAY_SMALL_1[:, 2:4, 2:4]}, + ], +] + +TEST_CASE_SMALL_5 = [ + { + "data": [ + {"image": FILE_PATH_SMALL_0, "level": 0, "size": (2, 2)}, + {"image": FILE_PATH_SMALL_1, "level": 0, "size": (3, 3)}, + ] + }, + [ + {"image": ARRAY_SMALL_0[:, 0:2, 0:2]}, + {"image": ARRAY_SMALL_0[:, 0:2, 2:4]}, + {"image": ARRAY_SMALL_0[:, 2:4, 0:2]}, + {"image": ARRAY_SMALL_0[:, 2:4, 2:4]}, + {"image": ARRAY_SMALL_1[:, 0:3, 0:3]}, + ], +] + +TEST_CASE_SMALL_6 = [ + { + "data": [ + {"image": FILE_PATH_SMALL_0, "level": 1, "size": (1, 1)}, + {"image": 
FILE_PATH_SMALL_1, "level": 2, "size": (4, 4)}, + ], + "size": (2, 2), + "level": 0, + }, + [ + {"image": ARRAY_SMALL_0[:, 0:2, 0:2]}, + {"image": ARRAY_SMALL_0[:, 0:2, 2:4]}, + {"image": ARRAY_SMALL_0[:, 2:4, 0:2]}, + {"image": ARRAY_SMALL_0[:, 2:4, 2:4]}, + {"image": ARRAY_SMALL_1[:, 0:2, 0:2]}, + {"image": ARRAY_SMALL_1[:, 0:2, 2:4]}, + {"image": ARRAY_SMALL_1[:, 2:4, 0:2]}, + {"image": ARRAY_SMALL_1[:, 2:4, 2:4]}, + ], +] + + +TEST_CASE_SMALL_7 = [ + {"data": [{"image": FILE_PATH_SMALL_0, "level": 0, "size": (2, 2)}], "offset": (1, 0)}, + [{"image": ARRAY_SMALL_0[:, 1:3, :2]}, {"image": ARRAY_SMALL_0[:, 1:3, 2:]}], +] + +TEST_CASE_SMALL_8 = [ + {"data": [{"image": FILE_PATH_SMALL_0, "level": 0, "size": (2, 2)}], "offset": "random", "offset_limits": (0, 2)}, + [{"image": ARRAY_SMALL_0[:, 1:3, :2]}, {"image": ARRAY_SMALL_0[:, 1:3, 2:]}], +] + +TEST_CASE_SMALL_9 = [ + { + "data": [{"image": FILE_PATH_SMALL_0, "level": 0, "size": (2, 2)}], + "offset": "random", + "offset_limits": ((0, 3), (0, 2)), + }, + [{"image": ARRAY_SMALL_0[:, :2, 1:3]}, {"image": ARRAY_SMALL_0[:, 2:, 1:3]}], +] + +TEST_CASE_LARGE_0 = [ + {"data": [{"image": FILE_PATH, "level": 8, "size": (64, 50)}]}, + [ + {"step_loc": (0, 0), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (0, 1), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (0, 2), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (1, 0), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (1, 1), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (1, 2), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + ], +] + +TEST_CASE_LARGE_1 = [ + { + "data": [ + {"image": FILE_PATH, "level": 8, "size": (64, 50)}, + {"image": FILE_PATH, "level": 7, "size": (125, 110)}, + ] + }, + [ + {"step_loc": (0, 0), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (0, 1), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (0, 2), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (1, 0), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (1, 1), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (1, 2), "size": (64, 50), "level": 8, "ratio": 257.06195068359375}, + {"step_loc": (0, 0), "size": (125, 110), "level": 7, "ratio": 128.10186767578125}, + {"step_loc": (0, 1), "size": (125, 110), "level": 7, "ratio": 128.10186767578125}, + {"step_loc": (0, 2), "size": (125, 110), "level": 7, "ratio": 128.10186767578125}, + {"step_loc": (1, 0), "size": (125, 110), "level": 7, "ratio": 128.10186767578125}, + {"step_loc": (1, 1), "size": (125, 110), "level": 7, "ratio": 128.10186767578125}, + {"step_loc": (1, 2), "size": (125, 110), "level": 7, "ratio": 128.10186767578125}, + ], +] + + +@skipUnless(has_cucim or has_tiff, "Requires cucim, openslide, or tifffile!") +def setUpModule(): # noqa: N802 + for info in [(ARRAY_SMALL_0, FILE_PATH_SMALL_0), (ARRAY_SMALL_1, FILE_PATH_SMALL_1)]: + array = info[0].transpose([1, 2, 0]) + imwrite(info[1], array, shape=array.shape, photometric="rgb") + hash_type = testing_data_config("images", FILE_KEY, "hash_type") + hash_val = testing_data_config("images", FILE_KEY, "hash_val") + download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val) + + +class SlidingPatchWSIDatasetTests: + class Tests(unittest.TestCase): + backend = None + + @parameterized.expand( + [ + TEST_CASE_SMALL_0, + 
TEST_CASE_SMALL_1, + TEST_CASE_SMALL_2, + TEST_CASE_SMALL_3, + TEST_CASE_SMALL_4, + TEST_CASE_SMALL_5, + TEST_CASE_SMALL_6, + TEST_CASE_SMALL_7, + TEST_CASE_SMALL_8, + TEST_CASE_SMALL_9, + ] + ) + def test_read_patches(self, input_parameters, expected): + if self.backend == "openslide": + return + dataset = SlidingPatchWSIDataset(reader=self.backend, **input_parameters) + self.assertEqual(len(dataset), len(expected)) + for i, sample in enumerate(dataset): + self.assertTupleEqual(sample["image"].shape, expected[i]["image"].shape) + + @parameterized.expand([TEST_CASE_LARGE_0, TEST_CASE_LARGE_1]) + def test_read_patches_large(self, input_parameters, expected): + dataset = SlidingPatchWSIDataset(reader=self.backend, **input_parameters) + self.assertEqual(len(dataset), len(expected)) + for i, sample in enumerate(dataset): + self.assertEqual(sample["metadata"]["patch"]["level"], expected[i]["level"]) + self.assertTupleEqual(sample["metadata"]["patch"]["size"], expected[i]["size"]) + steps = [round(expected[i]["ratio"] * s) for s in expected[i]["size"]] + expected_location = tuple(expected[i]["step_loc"][j] * steps[j] for j in range(len(steps))) + self.assertTupleEqual(sample["metadata"]["patch"]["location"], expected_location) + + +@skipUnless(has_cucim, "Requires cucim") +class TestSlidingPatchWSIDatasetCuCIM(SlidingPatchWSIDatasetTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "cucim" + + +@skipUnless(has_osl, "Requires openslide") +class TestSlidingPatchWSIDatasetOpenSlide(SlidingPatchWSIDatasetTests.Tests): + @classmethod + def setUpClass(cls): + cls.backend = "openslide" + + +if __name__ == "__main__": + unittest.main()
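The stepping rule introduced by `iter_patch_position` above is the core of the sliding-window behavior: the step per dimension is `round(patch_size * (1 - overlap))`, and with `padded=False` the last valid start position is `image_size - patch_size`. Below is a minimal standalone sketch of that rule; the helper name `sliding_starts` is only for illustration and is not part of the diff.

.. code-block:: python

    from itertools import product, starmap

    def sliding_starts(image_size, patch_size, overlap=0.0, padded=False):
        # Same stepping rule as iter_patch_position in monai/data/utils.py above:
        # step = round(patch_size * (1 - overlap)); without padding, the last valid
        # start is image_size - patch_size, so each range stops at that value + 1.
        ndim = len(image_size)
        overlaps = (overlap,) * ndim
        steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size, overlaps))
        end_pos = image_size if padded else tuple(s - p + 1 for s, p in zip(image_size, patch_size))
        start_pos = (0,) * ndim
        return list(product(*starmap(range, zip(start_pos, end_pos, steps))))

    # 4x4 image, 2x2 patches, 50% overlap -> step of 1 per dimension, nine patches,
    # matching TEST_CASE_SMALL_2 above.
    print(sliding_starts((4, 4), (2, 2), overlap=0.5))
    # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]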
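A hedged usage sketch for the new dataset, mirroring the small test cases above; the file path is a placeholder and `reader="cuCIM"` assumes cuCIM is installed.

.. code-block:: python

    from monai.data import SlidingPatchWSIDataset

    # Placeholder input; the tests above generate small TIFFs with tifffile.imwrite.
    data = [{"image": "path/to/slide.tiff", "level": 0, "size": (2, 2)}]

    dataset = SlidingPatchWSIDataset(data=data, overlap=0.5, offset=(0, 0), reader="cuCIM")
    print(len(dataset))  # one entry per sliding-window location
    for sample in dataset:
        patch = sample["image"]  # channel-first patch array
        location = sample["metadata"]["patch"]["location"]  # upper-left corner at level 0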
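The large-image tests above reconstruct the expected level-0 location of each patch from the level's down-sampling ratio, mirroring what `_evaluate_patch_coordinates` does via `get_downsample_ratio`. A worked example of that arithmetic, using the ratio hard-coded in TEST_CASE_LARGE_0:

.. code-block:: python

    # Level-8 patch of (64, 50) pixels; ratio taken from TEST_CASE_LARGE_0 above.
    ratio = 257.06195068359375
    size = (64, 50)

    # Step between neighboring patches at level 0, as computed in test_read_patches_large.
    steps = [round(ratio * s) for s in size]  # [16452, 12853]

    # Expected level-0 location of the patch at grid position (1, 2).
    step_loc = (1, 2)
    location = tuple(step_loc[i] * steps[i] for i in range(len(steps)))
    print(location)  # (16452, 25706)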