Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
753deca
Implement CuImageReader and OpenSlideReader
bhashemian Feb 4, 2021
ddbd6ab
Add unittests for CuImageReader
bhashemian Feb 4, 2021
7e77449
Add unittests for OpenSlideReader
bhashemian Feb 4, 2021
d3dbf7d
Merge branch 'master' into pathology_dataset
bhashemian Feb 4, 2021
c40b019
Sort imports
bhashemian Feb 4, 2021
f8b0962
Add correct boundaries
bhashemian Feb 5, 2021
e4dd37d
Merge branch 'pathology_dataset' of github.com:behxyz/MONAI into path…
bhashemian Feb 5, 2021
9a3e672
Add test cases for reading patches on a grid for CuImage
bhashemian Feb 5, 2021
b463310
Add patch whole slide imaging dataset for pathology
bhashemian Feb 5, 2021
4c735cb
Add test case for read patches for OpenSlide
bhashemian Feb 5, 2021
378893c
flake8 and few minor changes
bhashemian Feb 5, 2021
ec5261b
black
bhashemian Feb 5, 2021
ce01a9b
flake8
bhashemian Feb 5, 2021
51c1578
Add kwargs to CuImageReader and OpenSlideReader's read method
bhashemian Feb 8, 2021
714561a
Change the type hint from np.dtype to DTypeLike
bhashemian Feb 8, 2021
f6f5cf6
Merge branch 'master' into pathology_dataset
bhashemian Feb 8, 2021
642ee9b
Merge branch 'master' into pathology_dataset
bhashemian Feb 8, 2021
e83573d
Fix a bug
bhashemian Feb 8, 2021
1adf4ee
Merge branch 'master' into pathology_dataset
bhashemian Feb 22, 2021
097eb19
Implement WSIReader and unittests
bhashemian Feb 22, 2021
356e0d4
Minor updates
bhashemian Feb 22, 2021
27a04f6
Fix few typing issues
bhashemian Feb 23, 2021
9f09e49
Revert datasets
bhashemian Feb 23, 2021
4b9734f
Add shape property to openslide image object
bhashemian Feb 23, 2021
563314f
Add untittest for loading the whole image
bhashemian Feb 23, 2021
eb9655d
Update the whole image size
bhashemian Feb 23, 2021
71f9af4
Remove optional size
bhashemian Feb 23, 2021
3b98096
Remove optional dtype
bhashemian Feb 23, 2021
0076988
Remove _get_spatial_shape return type
bhashemian Feb 23, 2021
291846f
Reverse the orders of dimensions of `location`
bhashemian Feb 24, 2021
3ac7647
Change test cases to use smaller image and revese location's dimensions
bhashemian Feb 24, 2021
40a6f23
Merge branch 'master' into pathology_dataset
bhashemian Feb 24, 2021
00b7a55
Merge branch 'master' into pathology_dataset
bhashemian Feb 26, 2021
b851859
Replace the test TIFF and some upgrades
bhashemian Feb 26, 2021
0a99658
Update dependencies for OpenSlide
bhashemian Feb 26, 2021
dede661
Merge branch 'master' into pathology_dataset
bhashemian Mar 1, 2021
563a4fa
Update unittests for OpenSlide and CuImage
bhashemian Mar 1, 2021
9ee2200
Merge branch 'pathology_dataset' of pathology_dataset
bhashemian Mar 1, 2021
3ac12c3
Fix openslide dependency
bhashemian Mar 1, 2021
15c147d
Fix doc dependencies
bhashemian Mar 1, 2021
d9059ec
Merge branch 'master' into pathology_dataset
Nic-Ma Mar 3, 2021
c394ebe
Merge branch 'master' into pathology_dataset
bhashemian Mar 3, 2021
8a279c3
Minor changes
bhashemian Mar 3, 2021
c6171d1
Merge branch 'pathology_dataset' into pathology_dataset
bhashemian Mar 3, 2021
0082ac6
Merge branch 'master' into pathology_dataset
bhashemian Mar 3, 2021
22846f8
Merge branch 'master' into pathology_dataset
bhashemian Mar 4, 2021
c8750f0
Few variable name changes
bhashemian Mar 4, 2021
a440caf
Add EnsureChannelFirst
bhashemian Mar 4, 2021
d4ff431
Merge branch 'pathology_dataset' of github.com:behxyz/MONAI into path…
bhashemian Mar 4, 2021
652f046
Add metadata to WSIReader
bhashemian Mar 4, 2021
2ffdf58
Merge branch 'master' into pathology_dataset
bhashemian Mar 4, 2021
1f32f71
Merge branch 'master' into pathology_dataset
bhashemian Mar 5, 2021
2202f57
Merge branch 'master' into pathology_dataset
bhashemian Mar 5, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pytorch-ignite==0.4.2
numpy>=1.17
itk>=5.0
nibabel
openslide-python==1.1.2
parameterized
scikit-image>=0.14.2
tensorboard
Expand Down
4 changes: 4 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ PILReader
.. autoclass:: PILReader
:members:

WSIReader
~~~~~~~~~
.. autoclass:: WSIReader
:members:

Nifti format handling
---------------------
Expand Down
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
from .grid_dataset import GridPatchDataset, PatchDataset
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .iterable_dataset import IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
161 changes: 157 additions & 4 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,31 @@

from monai.config import DtypeLike, KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import ensure_tuple, optional_import

from .utils import is_supported_format

if TYPE_CHECKING:
import cuimage
import itk # type: ignore
import nibabel as nib
import openslide
from itk import Image # type: ignore
from nibabel.nifti1 import Nifti1Image
from PIL import Image as PILImage

has_itk = has_nib = has_pil = True
has_itk = has_nib = has_pil = has_cux = has_osl = True
else:
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
Image, _ = optional_import("itk", allow_namespace_pkg=True, name="Image")
nib, has_nib = optional_import("nibabel")
Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image")
PILImage, has_pil = optional_import("PIL.Image")
cuimage, has_cux = optional_import("cuimage")
openslide, has_osl = optional_import("openslide")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"]
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"]


class ImageReader(ABC):
Expand Down Expand Up @@ -264,10 +269,10 @@ def _get_affine(self, img) -> np.ndarray:
origin = np.asarray(img.GetOrigin())

direction = np.asarray(direction)
affine = np.eye(direction.shape[0] + 1)
affine: np.ndarray = np.eye(direction.shape[0] + 1)
affine[(slice(-1), slice(-1))] = direction @ np.diag(spacing)
affine[(slice(-1), -1)] = origin
return np.asarray(affine)
return affine

def _get_spatial_shape(self, img) -> np.ndarray:
"""
Expand Down Expand Up @@ -626,3 +631,151 @@ def _get_spatial_shape(self, img) -> np.ndarray:
"""
# the img data should have no channel dim or the last dim is channel
return np.asarray((img.width, img.height))


class WSIReader(ImageReader):
"""
Read whole slide imaging and extract patches

"""

def __init__(self, reader_lib: str = "cuClaraImage"):
super().__init__()
self.reader_lib = reader_lib.lower()
if self.reader_lib == "openslide":
self.wsi_reader = openslide.OpenSlide
print("> OpenSlide is being used.")
elif self.reader_lib == "cuclaraimage":
self.wsi_reader = cuimage.CuImage
print("> CuImage is being used.")
else:
raise ValueError('`reader_lib` should be either "cuClaraImage" or "OpenSlide"')

def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
"""
Verify whether the specified file or files format is supported by WSI reader.

Args:
filename: file name or a list of file names to read.
if a list of files, verify all the suffixes.
"""
return is_supported_format(filename, ["tif", "tiff"])

def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
"""
Read image data from specified file or files.
Note that the returned object is CuImage or list of CuImage objects.

Args:
data: file name or a list of file names to read.

"""
img_: List = []

filenames: Sequence[str] = ensure_tuple(data)
for name in filenames:
img = self.wsi_reader(name)
if self.reader_lib == "openslide":
img.shape = (img.dimensions[1], img.dimensions[0], 3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it support images with channel = 1 or no channel?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It only supports RGB data as the output and it explicitly convert it to RGB, so the channels always should be 3.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, the doc suggests RGBA (https://openslide.org/api/python/#openslide.OpenSlide.associated_images), I think we should make it clear in the docstring if we only support RGB

img_.append(img)

return img_ if len(filenames) > 1 else img_[0]

def get_data(
self,
img,
location: Tuple[int, int] = (0, 0),
size: Optional[Tuple[int, int]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest to use spatial_size, otherwise, it may be confusing whether it contains channel dim.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of agree here but since both OpenSlide and cuClaraImage use size to define this input argument, I found using another name confusing for people who has ML in pathology background and has worked with WSI in the past.

level: int = 0,
dtype: DtypeLike = np.uint8,
grid_shape: Tuple[int, int] = (1, 1),
patch_size: Optional[int] = None,
):
"""
Extract regions as numpy array from WSI image and return them.

Args:
img: a WSIReader image object loaded from a file, or list of CuImage objects
location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame,
or list of tuples (default=(0, 0))
size: (height, width) tuple giving the region size, or list of tuples (default to full image size)
This is the size of image at the given level (`level`)
level: the level number, or list of level numbers (default=0)
dtype: the data type of output image
grid_shape: (row, columns) tuple define a grid to extract patches on that
patch_size: (heigsht, width) the size of extracted patches at the given level
"""
if size is None:
if location == (0, 0):
# the maximum size is set to WxH
size = (img.shape[0] // (2 ** level), img.shape[1] // (2 ** level))
print(f"Reading the whole image at level={level} with shape={size}")
else:
raise ValueError("Size need to be provided to extract the region!")

region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

metadata: Dict = {}
metadata["spatial_shape"] = size
metadata["original_channel_dim"] = -1
region = EnsureChannelFirst()(region, metadata)

if patch_size is None:
patches = region
else:
patches = self._extract_patches(
region, patch_size=(patch_size, patch_size), grid_shape=grid_shape, dtype=dtype
)

return patches, metadata

def _extract_region(
self,
img_obj,
size: Tuple[int, int],
location: Tuple[int, int] = (0, 0),
level: int = 0,
dtype: DtypeLike = np.uint8,
):
# reverse the order of dimensions for size and location to be compatible with image shape
size = size[::-1]
location = location[::-1]
region = img_obj.read_region(location=location, size=size, level=level)
if self.reader_lib == "openslide":
region = region.convert("RGB")
# convert to numpy
region = np.asarray(region, dtype=dtype)

return region

def _extract_patches(
self,
region: np.ndarray,
grid_shape: Tuple[int, int] = (1, 1),
patch_size: Optional[Tuple[int, int]] = None,
dtype: DtypeLike = np.uint8,
):
if patch_size is None and grid_shape == (1, 1):
return region

n_patches = grid_shape[0] * grid_shape[1]
region_size = region.shape[1:]

if patch_size is None:
patch_size = (region_size[0] // grid_shape[0], region_size[1] // grid_shape[1])

# split the region into patches on the grid and center crop them to patch size
flat_patch_grid = np.zeros((n_patches, 3, patch_size[0], patch_size[1]), dtype=dtype)
start_points = [
np.round(region_size[i] * (0.5 + np.arange(grid_shape[i])) / grid_shape[i] - patch_size[i] / 2).astype(int)
for i in range(2)
]
idx = 0
for y_start in start_points[1]:
for x_start in start_points[0]:
x_end = x_start + patch_size[0]
y_end = y_start + patch_size[1]
flat_patch_grid[idx] = region[:, x_start:x_end, y_start:y_end]
idx += 1

return flat_patch_grid
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ Sphinx==3.3.0
recommonmark==0.6.0
sphinx-autodoc-typehints==1.11.1
sphinx-rtd-theme==0.5.0
openslide-python==1.1.2
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ all =
torchvision
itk>=5.0
tqdm>=4.47.0
openslide-python==1.1.2
nibabel =
nibabel
skimage =
Expand All @@ -54,6 +55,8 @@ lmdb =
lmdb
psutil =
psutil
openslide =
openslide-python==1.1.2

[flake8]
select = B,C,E,F,N,P,T4,W,B9
Expand Down
103 changes: 103 additions & 0 deletions tests/test_cuimage_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os
import unittest
from unittest import skipUnless
from urllib import request

import numpy as np
from numpy.testing import assert_array_equal
from parameterized import parameterized

from monai.data.image_reader import WSIReader
from monai.utils import optional_import

_, has_cui = optional_import("cuimage")


FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff"
HEIGHT = 32914
WIDTH = 46000

TEST_CASE_0 = [FILE_URL, (3, HEIGHT, WIDTH)]

TEST_CASE_1 = [
FILE_URL,
{"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
np.array([[[246], [246]], [[246], [246]], [[246], [246]]]),
]

TEST_CASE_2 = [
FILE_URL,
{"location": (0, 0), "size": (2, 1), "level": 2},
np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
]

TEST_CASE_3 = [
FILE_URL,
{
"location": (0, 0),
"size": (8, 8),
"level": 2,
"grid_shape": (2, 1),
"patch_size": 2,
},
np.array(
[
[[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]],
[[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]],
]
),
]

TEST_CASE_4 = [
FILE_URL,
{
"location": (0, 0),
"size": (8, 8),
"level": 2,
"grid_shape": (2, 1),
"patch_size": 1,
},
np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]),
]


class TestCuClaraImageReader(unittest.TestCase):
@parameterized.expand([TEST_CASE_0])
@skipUnless(has_cui, "Requires CuClaraImage")
def test_read_whole_image(self, file_url, expected_shape):
filename = self.camelyon_data_download(file_url)
reader = WSIReader("CuClaraImage")
img_obj = reader.read(filename)
img = reader.get_data(img_obj)[0]
self.assertTupleEqual(img.shape, expected_shape)

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@skipUnless(has_cui, "Requires CuClaraImage")
def test_read_region(self, file_url, patch_info, expected_img):
filename = self.camelyon_data_download(file_url)
reader = WSIReader("CuClaraImage")
img_obj = reader.read(filename)
img = reader.get_data(img_obj, **patch_info)[0]
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))

@parameterized.expand([TEST_CASE_3, TEST_CASE_4])
@skipUnless(has_cui, "Requires CuClaraImage")
def test_read_patches(self, file_url, patch_info, expected_img):
filename = self.camelyon_data_download(file_url)
reader = WSIReader("CuClaraImage")
img_obj = reader.read(filename)
img = reader.get_data(img_obj, **patch_info)[0]
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))

def camelyon_data_download(self, file_url):
filename = os.path.basename(file_url)
if not os.path.exists(filename):
print(f"Test image [{filename}] does not exist. Downloading...")
request.urlretrieve(file_url, filename)
return filename


if __name__ == "__main__":
unittest.main()
Loading