Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
2baa344
Redesign BaseWSIReader, WSIReader, CuCIMWSIReader
bhashemian Apr 12, 2022
e46fea0
Add unittests for WSIReader
bhashemian Apr 12, 2022
f946a63
Add image mode for output validation
bhashemian Apr 12, 2022
9b41019
Merge branch 'dev' into new-wsireader
bhashemian Apr 12, 2022
9c15ea3
Update docs
bhashemian Apr 13, 2022
e004538
Update references to new WSIReader
bhashemian Apr 13, 2022
df0a61e
Remove legacy WSIReader
bhashemian Apr 13, 2022
b64c087
Update unittests
bhashemian Apr 13, 2022
7f09e73
Update docs
bhashemian Apr 13, 2022
d95e1c2
sort imports
bhashemian Apr 13, 2022
025376b
Clean up imports
bhashemian Apr 13, 2022
3e6fb27
Update docstrings
bhashemian Apr 13, 2022
b8a2444
Update docs and docstrings
bhashemian Apr 13, 2022
0268a03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2022
f2b14aa
Fix a typo
bhashemian Apr 13, 2022
aed0ac5
Remove redundant checking
bhashemian Apr 13, 2022
272d11f
Merge branch 'new-wsireader' of github.com:drbeh/MONAI into new-wsire…
bhashemian Apr 13, 2022
ae4ff17
Update read and other methods
bhashemian Apr 14, 2022
e0bad98
Merge branch 'pathology' of github.com:Project-MONAI/MONAI into new-w…
bhashemian Apr 14, 2022
e5a7a18
Update wsireader to support multi image and update docstrings
bhashemian Apr 18, 2022
da6b675
Make workaround for CuImage objects
bhashemian Apr 18, 2022
1ed8154
Add unittests for multi image reading
bhashemian Apr 18, 2022
77f8747
Update a note about cucim
bhashemian Apr 18, 2022
36e5b17
Update type hints and docstrings
bhashemian Apr 19, 2022
0328fe6
Merge branch 'pathology' of github.com:Project-MONAI/MONAI into new-w…
bhashemian Apr 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,24 @@ PILReader
.. autoclass:: PILReader
:members:

Whole slide image reader
------------------------

BaseWSIReader
~~~~~~~~~~~~~
.. autoclass:: BaseWSIReader
:members:

WSIReader
~~~~~~~~~
.. autoclass:: WSIReader
:members:

CuCIMWSIReader
~~~~~~~~~~~~~~
.. autoclass:: CuCIMWSIReader
:members:

Image writer
------------

Expand Down
2 changes: 1 addition & 1 deletion monai/apps/pathology/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np

from monai.data import Dataset, SmartCacheDataset
from monai.data.image_reader import WSIReader
from monai.data.wsi_reader import WSIReader
from monai.utils import ensure_tuple_rep

__all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset", "MaskedInferenceWSIDataset"]
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/pathology/metrics/lesion_froc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np

from monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask
from monai.data.image_reader import WSIReader
from monai.data.wsi_reader import WSIReader
from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
from monai.utils import min_version, optional_import

Expand Down
3 changes: 2 additions & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .folder_layout import FolderLayout
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from .image_writer import (
SUPPORTED_WRITERS,
ImageWriter,
Expand Down Expand Up @@ -87,3 +87,4 @@
worker_init_fn,
zoom_affine,
)
from .wsi_reader import BaseWSIReader, CuCIMWSIReader, WSIReader
267 changes: 2 additions & 265 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

from monai.config import DtypeLike, KeysCollection, PathLike
from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format, orientation_ras_lps
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg
from monai.utils import ensure_tuple, optional_import, require_pkg

if TYPE_CHECKING:
import itk
Expand All @@ -39,7 +38,7 @@
CuImage, _ = optional_import("cucim", name="CuImage")
TiffFile, _ = optional_import("tifffile", name="TiffFile")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"]
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"]


class ImageReader(ABC):
Expand Down Expand Up @@ -714,265 +713,3 @@ def _get_spatial_shape(self, img):
img: a PIL Image object loaded from an image file.
"""
return np.asarray((img.width, img.height))


class WSIReader(ImageReader):
"""
Read whole slide images and extract patches.

Args:
backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile".
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
kwargs: additional args for backend reading API in `read()`, more details in `cuCIM`, `TiffFile`, `OpenSlide`:
https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100.
https://github.com/cgohlke/tifffile.
https://openslide.org/api/python/#openslide.OpenSlide.

Note:
While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images
without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory
before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for
patch extraction.

"""

def __init__(self, backend: str = "OpenSlide", level: int = 0, **kwargs):
super().__init__()
self.backend = backend.lower()
func = require_pkg(self.backend)(self._set_reader)
self.wsi_reader = func(self.backend)
self.level = level
self.kwargs = kwargs

@staticmethod
def _set_reader(backend: str):
if backend == "openslide":
return OpenSlide
if backend == "cucim":
return CuImage
if backend == "tifffile":
return TiffFile
raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.")

def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
"""
Verify whether the specified file or files format is supported by WSI reader.

Args:
filename: file name or a list of file names to read.
if a list of files, verify all the suffixes.
"""
return is_supported_format(filename, ["tif", "tiff"])

def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
Read image data from given file or list of files.

Args:
data: file name or a list of file names to read.
kwargs: additional args for backend reading API in `read()`, will override `self.kwargs` for existing keys.
more details in `cuCIM`, `TiffFile`, `OpenSlide`:
https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100.
https://github.com/cgohlke/tifffile.
https://openslide.org/api/python/#openslide.OpenSlide.

Returns:
image object or list of image objects

"""
img_: List = []

filenames: Sequence[PathLike] = ensure_tuple(data)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
img = self.wsi_reader(name, **kwargs_)
if self.backend == "openslide":
img.shape = (img.dimensions[1], img.dimensions[0], 3)
img_.append(img)

return img_ if len(filenames) > 1 else img_[0]

def get_data(
self,
img,
location: Tuple[int, int] = (0, 0),
size: Optional[Tuple[int, int]] = None,
level: Optional[int] = None,
dtype: DtypeLike = np.uint8,
grid_shape: Tuple[int, int] = (1, 1),
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
):
"""
Extract regions as numpy array from WSI image and return them.

Args:
img: a WSIReader image object loaded from a file, or list of CuImage objects
location: (x_min, y_min) tuple giving the top left pixel in the level 0 reference frame,
or list of tuples (default=(0, 0))
size: (height, width) tuple giving the region size, or list of tuples (default to full image size)
This is the size of image at the given level (`level`)
level: the level number, or list of level numbers (default=0)
dtype: the data type of output image
grid_shape: (row, columns) tuple define a grid to extract patches on that
patch_size: (height, width) the size of extracted patches at the given level
"""
# Verify inputs
if level is None:
level = self.level
max_level = self._get_max_level(img)
if level > max_level:
raise ValueError(f"The maximum level of this image is {max_level} while level={level} is requested)!")

# Extract a region or the entire image
region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

# Add necessary metadata
metadata: Dict = {}
metadata["spatial_shape"] = np.asarray(region.shape[:-1])
metadata["original_channel_dim"] = -1

# Make it channel first
region = EnsureChannelFirst()(region, metadata)

# Split into patches
if patch_size is None:
patches = region
else:
tuple_patch_size = ensure_tuple_rep(patch_size, 2)
patches = self._extract_patches(
region, patch_size=tuple_patch_size, grid_shape=grid_shape, dtype=dtype # type: ignore
)

return patches, metadata

def _get_max_level(self, img_obj):
"""
Return the maximum number of levels in the whole slide image
Args:
img: the whole slide image object

"""
if self.backend == "openslide":
return img_obj.level_count - 1
if self.backend == "cucim":
return img_obj.resolutions["level_count"] - 1
if self.backend == "tifffile":
return len(img_obj.pages) - 1

def _get_image_size(self, img, size, level, location):
"""
Calculate the maximum region size for the given level and starting location (if size is None).
Note that region size in OpenSlide and cuCIM are WxH (but the final image output would be HxW)
"""
if size is not None:
return size[::-1]

max_size = []
downsampling_factor = []
if self.backend == "openslide":
downsampling_factor = img.level_downsamples[level]
max_size = img.level_dimensions[level]
elif self.backend == "cucim":
downsampling_factor = img.resolutions["level_downsamples"][level]
max_size = img.resolutions["level_dimensions"][level]

# subtract the top left corner of the patch (at given level) from maximum size
location_at_level = (round(location[1] / downsampling_factor), round(location[0] / downsampling_factor))
size = [max_size[i] - location_at_level[i] for i in range(len(max_size))]

return size

def _extract_region(
self,
img_obj,
size: Optional[Tuple[int, int]],
location: Tuple[int, int] = (0, 0),
level: int = 0,
dtype: DtypeLike = np.uint8,
):
if self.backend == "tifffile":
# Read the entire image
if size is not None:
raise ValueError(
f"TiffFile backend reads the entire image only, so size '{size}'' should not be provided!",
"For more flexibility or extracting regions, please use cuCIM or OpenSlide backend.",
)
if location != (0, 0):
raise ValueError(
f"TiffFile backend reads the entire image only, so location '{location}' should not be provided!",
"For more flexibility and extracting regions, please use cuCIM or OpenSlide backend.",
)
region = img_obj.asarray(level=level)
else:
# Get region size to be extracted
region_size = self._get_image_size(img_obj, size, level, location)
# reverse the order of location's dimensions to become WxH (for cuCIM and OpenSlide)
region_location = location[::-1]
# Extract a region (or the entire image)
region = img_obj.read_region(location=region_location, size=region_size, level=level)

region = self.convert_to_rgb_array(region, dtype)
return region

def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8):
"""Convert to RGB mode and numpy array"""
if self.backend == "openslide":
# convert to RGB
raw_region = raw_region.convert("RGB")

# convert to numpy (if not already in numpy)
raw_region = np.asarray(raw_region, dtype=dtype)

# check if the image has three dimensions (2D + color)
if raw_region.ndim != 3:
raise ValueError(
f"The input image dimension should be 3 but {raw_region.ndim} is given. "
"`WSIReader` is designed to work only with 2D colored images."
)

# check if the color channel is 3 (RGB) or 4 (RGBA)
if raw_region.shape[-1] not in [3, 4]:
raise ValueError(
f"There should be three or four color channels but {raw_region.shape[-1]} is given. "
"`WSIReader` is designed to work only with 2D colored images."
)

# remove alpha channel if exist (RGBA)
if raw_region.shape[-1] > 3:
raw_region = raw_region[..., :3]

return raw_region

def _extract_patches(
self,
region: np.ndarray,
grid_shape: Tuple[int, int] = (1, 1),
patch_size: Optional[Tuple[int, int]] = None,
dtype: DtypeLike = np.uint8,
):
if patch_size is None and grid_shape == (1, 1):
return region

n_patches = grid_shape[0] * grid_shape[1]
region_size = region.shape[1:]

if patch_size is None:
patch_size = (region_size[0] // grid_shape[0], region_size[1] // grid_shape[1])

# split the region into patches on the grid and center crop them to patch size
flat_patch_grid = np.zeros((n_patches, 3, patch_size[0], patch_size[1]), dtype=dtype)
start_points = [
np.round(region_size[i] * (0.5 + np.arange(grid_shape[i])) / grid_shape[i] - patch_size[i] / 2).astype(int)
for i in range(2)
]
idx = 0
for y_start in start_points[1]:
for x_start in start_points[0]:
x_end = x_start + patch_size[0]
y_end = y_start + patch_size[1]
flat_patch_grid[idx] = region[:, x_start:x_end, y_start:y_end]
idx += 1

return flat_patch_grid
Loading