Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/apps/pathology/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class SplitOnGrid(Transform):
If it's an integer, the value will be repeated for each dimension. Default is 2x2
patch_size: a tuple or an integer that defines the output patch sizes.
If it's an integer, the value will be repeated for each dimension.
The default is (0, 0), where the patch size will be infered from the grid shape.
The default is (0, 0), where the patch size will be inferred from the grid shape.

Note: the shape of the input image is infered based on the first image used.
Note: the shape of the input image is inferred based on the first image used.
"""

def __init__(
Expand Down
49 changes: 31 additions & 18 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,15 @@ class WSIReader(ImageReader):
Read whole slide images and extract patches.

Args:
backend: backend library to load the images, available options: "OpenSlide" or "cuCIM".
backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "Tifffile".
level: the whole slide image level at which the image is extracted. (default=0)
Note that this is overridden by the level argument in `get_data`.
This is overridden if the level argument is provided in `get_data`.

Note:
While "cucim" and "OpenSlide" backends both can load patches from large whole slide images
without loading the entire image into memory, "Tifffile" backend needs to load the entire image into memory
before extracting any patch; thus, memory consideration is needed when using "Tifffile" backend for
patch extraction.
"""

def __init__(self, backend: str = "OpenSlide", level: int = 0):
Expand All @@ -689,8 +694,10 @@ def __init__(self, backend: str = "OpenSlide", level: int = 0):
self.wsi_reader, *_ = optional_import("openslide", name="OpenSlide")
elif self.backend == "cucim":
self.wsi_reader, *_ = optional_import("cucim", name="CuImage")
elif self.backend == "tifffile":
self.wsi_reader, *_ = optional_import("tifffile", name="TiffFile")
else:
raise ValueError('`backend` should be either "cuCIM" or "OpenSlide"')
raise ValueError('`backend` should be "cuCIM", "OpenSlide", or "TiffFile')
self.level = level

def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
Expand Down Expand Up @@ -780,13 +787,22 @@ def _extract_region(
level: int = 0,
dtype: DtypeLike = np.uint8,
):
# reverse the order of dimensions for size and location to be compatible with image shape
location = location[::-1]
if size is None:
region = img_obj.read_region(location=location, level=level)
if self.backend == "tifffile":
with img_obj:
region = img_obj.asarray(level=level)
if size is None:
region = region[location[0] :, location[1] :]
else:
region = region[location[0] : location[0] + size[0], location[1] : location[1] + size[1]]

else:
size = size[::-1]
region = img_obj.read_region(location=location, size=size, level=level)
# reverse the order of dimensions for size and location to be compatible with image shape
location = location[::-1]
if size is None:
region = img_obj.read_region(location=location, level=level)
else:
size = size[::-1]
region = img_obj.read_region(location=location, size=size, level=level)

region = self.convert_to_rgb_array(region, dtype)
return region
Expand All @@ -796,15 +812,12 @@ def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8):
if self.backend == "openslide":
# convert to RGB
raw_region = raw_region.convert("RGB")
# convert to numpy
raw_region = np.asarray(raw_region, dtype=dtype)
else:
num_channels = len(raw_region.channel_names)
# convert to numpy
raw_region = np.asarray(raw_region, dtype=dtype)
# remove alpha channel if exist (RGBA)
if num_channels > 3:
raw_region = raw_region[:, :, :3]

# convert to numpy (if not already in numpy)
raw_region = np.asarray(raw_region, dtype=dtype)
# remove alpha channel if exist (RGBA)
if raw_region.shape[-1] > 3:
raw_region = raw_region[..., :3]

return raw_region

Expand Down
9 changes: 9 additions & 0 deletions tests/test_wsireader.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def test_read_whole_image(self, file_path, level, expected_shape):
def test_read_region(self, file_path, patch_info, expected_img):
reader = WSIReader(self.backend)
img_obj = reader.read(file_path)
# Read twice to check multiple calls
img = reader.get_data(img_obj, **patch_info)[0]
img = reader.get_data(img_obj, **patch_info)[0]
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))
Expand Down Expand Up @@ -171,5 +173,12 @@ def setUpClass(cls):
cls.backend = "openslide"


@skipUnless(has_tiff, "Requires TiffFile")
class TestTiffFile(WSIReaderTests.Tests):
@classmethod
def setUpClass(cls):
cls.backend = "tifffile"


if __name__ == "__main__":
unittest.main()