diff --git a/docs/requirements.txt b/docs/requirements.txt index d046bc53cf..cd06166359 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,6 +4,7 @@ pytorch-ignite==0.4.2 numpy>=1.17 itk>=5.0 nibabel +openslide-python==1.1.2 parameterized scikit-image>=0.14.2 tensorboard diff --git a/docs/source/data.rst b/docs/source/data.rst index 11609964c3..eed4b30ded 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -105,6 +105,10 @@ PILReader .. autoclass:: PILReader :members: +WSIReader +~~~~~~~~~ +.. autoclass:: WSIReader + :members: Nifti format handling --------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 3dd0a980ef..54ee7908f4 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -24,7 +24,7 @@ from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties from .grid_dataset import GridPatchDataset, PatchDataset from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index dfbdaf5b41..76bf1817dc 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -19,26 +19,31 @@ from monai.config import DtypeLike, KeysCollection from monai.data.utils import correct_nifti_header_if_necessary +from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import ensure_tuple, optional_import from .utils import is_supported_format if TYPE_CHECKING: + import cuimage import itk # type: ignore import nibabel as nib + import openslide from itk import Image # type: ignore from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - has_itk = has_nib = has_pil = True + has_itk = has_nib 
class WSIReader(ImageReader):
    """
    Read whole slide images (WSI) and extract regions or patches from them.

    Args:
        reader_lib: name of the backend library used to open the slides, either
            "cuClaraImage" or "OpenSlide". The comparison is case-insensitive.

    Raises:
        ValueError: when `reader_lib` is neither "cuClaraImage" nor "OpenSlide".

    """

    def __init__(self, reader_lib: str = "cuClaraImage"):
        super().__init__()
        # normalise the name once so every later comparison can use the lower-case form
        self.reader_lib = reader_lib.lower()
        if self.reader_lib == "openslide":
            self.wsi_reader = openslide.OpenSlide
            print("> OpenSlide is being used.")
        elif self.reader_lib == "cuclaraimage":
            self.wsi_reader = cuimage.CuImage
            print("> CuImage is being used.")
        else:
            raise ValueError('`reader_lib` should be either "cuClaraImage" or "OpenSlide"')

    def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
        """
        Verify whether the specified file or files format is supported by WSI reader.

        Args:
            filename: file name or a list of file names to read.
                if a list of files, verify all the suffixes.

        Returns:
            the result of `is_supported_format` for the "tif" and "tiff" suffixes.
        """
        return is_supported_format(filename, ["tif", "tiff"])

    def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
        """
        Open the specified file or files with the selected backend.

        Args:
            data: file name or a list of file names to read.
            kwargs: accepted for interface compatibility; not used by this reader.

        Returns:
            a single backend image object (`openslide.OpenSlide` or `cuimage.CuImage`,
            depending on `reader_lib`) when one file name is given, otherwise a list
            with one image object per file name.

        """
        img_: List = []

        filenames: Sequence[str] = ensure_tuple(data)
        for name in filenames:
            img = self.wsi_reader(name)
            if self.reader_lib == "openslide":
                # OpenSlide reports `dimensions` as (width, height); attach a
                # (height, width, channel) `shape` so that `get_data` can read
                # `img.shape` regardless of the backend in use.
                img.shape = (img.dimensions[1], img.dimensions[0], 3)
            img_.append(img)

        return img_ if len(filenames) > 1 else img_[0]

    def get_data(
        self,
        img,
        location: Tuple[int, int] = (0, 0),
        size: Optional[Tuple[int, int]] = None,
        level: int = 0,
        dtype: DtypeLike = np.uint8,
        grid_shape: Tuple[int, int] = (1, 1),
        patch_size: Optional[int] = None,
    ):
        """
        Extract a region (optionally split into patches) as a numpy array from a WSI image.

        Args:
            img: an image object returned by `read` for a single file.
            location: (row, column) tuple giving the top left pixel of the region
                (default=(0, 0)). It is reversed to (column, row) before being handed
                to the backend `read_region` call.
            size: (height, width) tuple giving the region size at the given level (`level`).
                If omitted, `location` must be (0, 0) and the whole image is read.
            level: the level number (default=0).
            dtype: the data type of output image.
            grid_shape: (rows, columns) tuple defining a grid to extract patches on.
                It is only used when `patch_size` is provided.
            patch_size: edge length of the square patches extracted at the given level.
                If omitted, the region is returned without being split into patches.

        Returns:
            a tuple `(data, metadata)`, where `data` is the channel-first region, or the
            stack of patches when `patch_size` is given, and `metadata` is a dictionary
            holding "spatial_shape" and "original_channel_dim".

        Raises:
            ValueError: when `size` is omitted and `location` is not (0, 0).
        """
        if size is None:
            # `tuple(...)` lets an origin given as a list, e.g. [0, 0], match as well
            if tuple(location) != (0, 0):
                raise ValueError("Size need to be provided to extract the region!")
            # The maximum size is (height, width) of the image at the requested level.
            # NOTE(review): this assumes every level halves the previous one; confirm
            # against the level dimensions reported by the backend.
            size = (img.shape[0] // (2 ** level), img.shape[1] // (2 ** level))
            print(f"Reading the whole image at level={level} with shape={size}")

        region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

        metadata: Dict = {}
        metadata["spatial_shape"] = size
        # the backend returns channel-last data; record it so it can be moved to the front
        metadata["original_channel_dim"] = -1
        region = EnsureChannelFirst()(region, metadata)

        if patch_size is None:
            patches = region
        else:
            patches = self._extract_patches(
                region, patch_size=(patch_size, patch_size), grid_shape=grid_shape, dtype=dtype
            )

        return patches, metadata

    def _extract_region(
        self,
        img_obj,
        size: Tuple[int, int],
        location: Tuple[int, int] = (0, 0),
        level: int = 0,
        dtype: DtypeLike = np.uint8,
    ):
        """
        Read one region from the backend image object and return it as a numpy array.

        Args:
            img_obj: backend image object exposing `read_region`.
            size: (height, width) of the region at `level`.
            location: (row, column) of the top left pixel of the region.
            level: the level number to read from.
            dtype: the data type of the returned array.
        """
        # reverse the order of dimensions for size and location to be compatible with image shape
        size = size[::-1]
        location = location[::-1]
        region = img_obj.read_region(location=location, size=size, level=level)
        if self.reader_lib == "openslide":
            # convert the image returned by OpenSlide to three-channel RGB
            region = region.convert("RGB")
        # convert to numpy
        region = np.asarray(region, dtype=dtype)

        return region

    def _extract_patches(
        self,
        region: np.ndarray,
        grid_shape: Tuple[int, int] = (1, 1),
        patch_size: Optional[Tuple[int, int]] = None,
        dtype: DtypeLike = np.uint8,
    ):
        """
        Split a channel-first region into patches centred on the cells of a regular grid.

        Args:
            region: array indexed as (channel, row, column).
            grid_shape: (rows, columns) of the grid laid over the region.
            patch_size: (height, width) of each patch. Defaults to the size of one grid cell.
            dtype: the data type of the returned array.

        Returns:
            `region` unchanged when `patch_size` is None and `grid_shape` is (1, 1);
            otherwise an array of shape (rows * columns, channels, patch height, patch width).
            Patches are stored with the grid row index varying fastest.
        """
        if patch_size is None and grid_shape == (1, 1):
            return region

        n_patches = grid_shape[0] * grid_shape[1]
        # take the channel count from the data instead of assuming three channels
        n_channels = region.shape[0]
        region_size = region.shape[1:]

        if patch_size is None:
            patch_size = (region_size[0] // grid_shape[0], region_size[1] // grid_shape[1])

        # split the region into patches on the grid and center crop them to patch size
        flat_patch_grid = np.zeros((n_patches, n_channels, patch_size[0], patch_size[1]), dtype=dtype)
        # per axis: centre of every grid cell minus half a patch gives the patch start
        start_points = [
            np.round(region_size[i] * (0.5 + np.arange(grid_shape[i])) / grid_shape[i] - patch_size[i] / 2).astype(int)
            for i in range(2)
        ]
        idx = 0
        for col_start in start_points[1]:
            col_end = col_start + patch_size[1]
            for row_start in start_points[0]:
                row_end = row_start + patch_size[0]
                flat_patch_grid[idx] = region[:, row_start:row_end, col_start:col_end]
                idx += 1

        return flat_patch_grid
class TestCuClaraImageReader(unittest.TestCase):
    """Checks `WSIReader` with the cuClaraImage backend against a downloaded test slide."""

    @parameterized.expand([TEST_CASE_0])
    @skipUnless(has_cui, "Requires CuClaraImage")
    def test_read_whole_image(self, file_url, expected_shape):
        # reading without any region arguments should yield the full channel-first image
        slide_path = self.camelyon_data_download(file_url)
        wsi_reader = WSIReader("CuClaraImage")
        slide = wsi_reader.read(slide_path)
        outputs = wsi_reader.get_data(slide)
        pixels = outputs[0]
        self.assertTupleEqual(pixels.shape, expected_shape)

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    @skipUnless(has_cui, "Requires CuClaraImage")
    def test_read_region(self, file_url, patch_info, expected_img):
        # a single region must match the reference pixels in both shape and content
        slide_path = self.camelyon_data_download(file_url)
        wsi_reader = WSIReader("CuClaraImage")
        slide = wsi_reader.read(slide_path)
        outputs = wsi_reader.get_data(slide, **patch_info)
        pixels = outputs[0]
        self.assertTupleEqual(pixels.shape, expected_img.shape)
        comparison = assert_array_equal(pixels, expected_img)
        self.assertIsNone(comparison)

    @parameterized.expand([TEST_CASE_3, TEST_CASE_4])
    @skipUnless(has_cui, "Requires CuClaraImage")
    def test_read_patches(self, file_url, patch_info, expected_img):
        # patches extracted on a grid must match the reference stack of patches
        slide_path = self.camelyon_data_download(file_url)
        wsi_reader = WSIReader("CuClaraImage")
        slide = wsi_reader.read(slide_path)
        outputs = wsi_reader.get_data(slide, **patch_info)
        pixels = outputs[0]
        self.assertTupleEqual(pixels.shape, expected_img.shape)
        comparison = assert_array_equal(pixels, expected_img)
        self.assertIsNone(comparison)

    def camelyon_data_download(self, file_url):
        """Return the local name of the test slide, downloading it first if it is missing."""
        local_name = os.path.basename(file_url)
        if os.path.exists(local_name):
            return local_name
        print(f"Test image [{local_name}] does not exist. Downloading...")
        request.urlretrieve(file_url, local_name)
        return local_name
class TestOpenSlideReader(unittest.TestCase):
    """Checks `WSIReader` with the OpenSlide backend against a downloaded test slide."""

    @parameterized.expand([TEST_CASE_0])
    @skipUnless(has_osl, "Requires OpenSlide")
    def test_read_whole_image(self, file_url, expected_shape):
        # reading without any region arguments should yield the full channel-first image
        slide_path = self.camelyon_data_download(file_url)
        wsi_reader = WSIReader("OpenSlide")
        slide = wsi_reader.read(slide_path)
        outputs = wsi_reader.get_data(slide)
        pixels = outputs[0]
        self.assertTupleEqual(pixels.shape, expected_shape)

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    @skipUnless(has_osl, "Requires OpenSlide")
    def test_read_region(self, file_url, patch_info, expected_img):
        # a single region must match the reference pixels in both shape and content
        slide_path = self.camelyon_data_download(file_url)
        wsi_reader = WSIReader("OpenSlide")
        slide = wsi_reader.read(slide_path)
        outputs = wsi_reader.get_data(slide, **patch_info)
        pixels = outputs[0]
        self.assertTupleEqual(pixels.shape, expected_img.shape)
        comparison = assert_array_equal(pixels, expected_img)
        self.assertIsNone(comparison)

    @parameterized.expand([TEST_CASE_3, TEST_CASE_4])
    @skipUnless(has_osl, "Requires OpenSlide")
    def test_read_patches(self, file_url, patch_info, expected_img):
        # patches extracted on a grid must match the reference stack of patches
        slide_path = self.camelyon_data_download(file_url)
        wsi_reader = WSIReader("OpenSlide")
        slide = wsi_reader.read(slide_path)
        outputs = wsi_reader.get_data(slide, **patch_info)
        pixels = outputs[0]
        self.assertTupleEqual(pixels.shape, expected_img.shape)
        comparison = assert_array_equal(pixels, expected_img)
        self.assertIsNone(comparison)

    def camelyon_data_download(self, file_url):
        """Return the local name of the test slide, downloading it first if it is missing."""
        local_name = os.path.basename(file_url)
        if os.path.exists(local_name):
            return local_name
        print(f"Test image [{local_name}] does not exist. Downloading...")
        request.urlretrieve(file_url, local_name)
        return local_name