Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
7f221ff
Implement SlidingPatchWSIDataset
bhashemian May 2, 2022
b011f41
Update docs
bhashemian May 2, 2022
2f7c881
Merge branch 'dev' of github.com:Project-MONAI/MONAI into sliding-patch
bhashemian May 6, 2022
3830551
Fix a typo in skip message
bhashemian May 6, 2022
ba20427
Update patch/steps at different levels
bhashemian May 6, 2022
1db779c
Update to sliding_sample
bhashemian May 6, 2022
0fc37c3
Implement unittests for SlidingPatchWSIReader
bhashemian May 6, 2022
114c234
Merge dev branch
bhashemian May 6, 2022
577b097
Remove unused import
bhashemian May 6, 2022
9cd07bb
Minor updates and add openslide tests
bhashemian May 9, 2022
465ecf5
Merge branch 'dev' into sliding-patch
bhashemian May 9, 2022
09a5e1a
Merge dev
bhashemian May 16, 2022
f206ecd
Update docstring
bhashemian May 16, 2022
cbd477a
Implement iter_wsi_patch_location
bhashemian May 17, 2022
aba82b3
Update docstring
bhashemian May 17, 2022
5171287
Uncomment openslide
bhashemian May 17, 2022
71e70c3
Reorder imports
bhashemian May 17, 2022
2e40545
Add overlap and more updates
bhashemian May 17, 2022
b733ec3
Remove product
bhashemian May 17, 2022
1aefdd1
Add overlap to iter_patch
bhashemian May 17, 2022
c57bec4
Update docstring
bhashemian May 17, 2022
f2a3a22
fix an arg
bhashemian May 18, 2022
0ab908d
Merge branch 'dev' into sliding-patch
bhashemian May 18, 2022
a920089
Merge branch 'dev' into sliding-patch
bhashemian May 19, 2022
175b6bd
Merge branch 'dev' of github.com:Project-MONAI/MONAI into sliding-patch
bhashemian May 20, 2022
03b9dcf
Merge branch 'dev' into sliding-patch
bhashemian May 22, 2022
dd16102
Merge branch 'dev' of github.com:Project-MONAI/MONAI into sliding-patch
bhashemian May 23, 2022
cc043d4
Include padded patches
bhashemian May 23, 2022
a0e6099
Merge branch 'sliding-patch' of github.com:behxyz/MONAI into sliding-…
bhashemian May 23, 2022
1700fda
Merge branch 'dev' into sliding-patch
bhashemian May 23, 2022
0128c3c
Separate overlap for each dimension
bhashemian May 23, 2022
30ed298
Merge branch 'sliding-patch' of github.com:behxyz/MONAI into sliding-…
bhashemian May 23, 2022
5466063
Fix docstring issue
bhashemian May 24, 2022
eb7d3f2
Remove patch number
bhashemian May 24, 2022
2d0ecbb
Merge branch 'sliding-patch' of github.com:behxyz/MONAI into sliding-…
bhashemian May 24, 2022
f6cd795
Change to get_downsample_ratio
bhashemian May 24, 2022
e5cb3ab
minor fixes
bhashemian May 24, 2022
9578c4d
Merge branch 'dev' into sliding-patch
bhashemian May 24, 2022
7437854
Implemenet random offset
bhashemian May 24, 2022
441abd2
Combine iter_patch_slices and iter_wsi_patch_locations
bhashemian May 24, 2022
d998eb5
Update iter_patch_slices
bhashemian May 25, 2022
31b5fbc
Merge branch 'dev' into sliding-patch
bhashemian May 25, 2022
5447128
Move iter_patch_slices and fix docstrings
bhashemian May 25, 2022
ebe7143
Merge branch 'dev' into sliding-patch
bhashemian May 25, 2022
3ebd8c1
Update init
bhashemian May 25, 2022
b520d68
Merge branch 'dev' into sliding-patch
bhashemian May 25, 2022
e219e58
Fixed randomness
bhashemian May 25, 2022
835ce91
Merge branch 'dev' into sliding-patch
wyli May 26, 2022
971d91a
Merge branch 'dev' into sliding-patch
wyli May 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ PatchWSIDataset
.. autoclass:: monai.data.PatchWSIDataset
:members:

SlidingPatchWSIDataset
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.SlidingPatchWSIDataset
:members:

Bounding box
------------
.. automodule:: monai.data.box_utils
Expand Down
3 changes: 2 additions & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
get_valid_patch_size,
is_supported_format,
iter_patch,
iter_patch_position,
iter_patch_slices,
json_hashing,
list_data_collate,
Expand All @@ -103,7 +104,7 @@
worker_init_fn,
zoom_affine,
)
from .wsi_datasets import PatchWSIDataset
from .wsi_datasets import PatchWSIDataset, SlidingPatchWSIDataset
from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, WSIReader

with contextlib.suppress(BaseException):
Expand Down
1 change: 1 addition & 0 deletions monai/data/grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __call__(self, array: np.ndarray):
array,
patch_size=self.patch_size, # type: ignore
start_pos=self.start_pos,
overlap=0.0,
copy_back=False,
mode=self.mode,
**self.pad_opts,
Expand Down
81 changes: 66 additions & 15 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"get_valid_patch_size",
"is_supported_format",
"iter_patch",
"iter_patch_position",
"iter_patch_slices",
"json_hashing",
"list_data_collate",
Expand Down Expand Up @@ -123,32 +124,36 @@ def get_random_patch(


def iter_patch_slices(
dims: Sequence[int], patch_size: Union[Sequence[int], int], start_pos: Sequence[int] = ()
image_size: Sequence[int],
patch_size: Union[Sequence[int], int],
start_pos: Sequence[int] = (),
overlap: Union[Sequence[float], float] = 0.0,
padded: bool = True,
) -> Generator[Tuple[slice, ...], None, None]:
"""
Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `dims`. The
iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each
patch is chosen in a contiguous grid using a first dimension as least significant ordering.
Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `image_size`.
The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each
patch is chosen in a contiguous grid using a rwo-major ordering.

Args:
dims: dimensions of array to iterate over
image_size: dimensions of array to iterate over
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
padded: if the image is padded so the patches can go beyond the borders. Defaults to False.

Yields:
Tuples of slice objects defining each patch
"""

# ensure patchSize and startPos are the right length
ndim = len(dims)
patch_size_ = get_valid_patch_size(dims, patch_size)
start_pos = ensure_tuple_size(start_pos, ndim)
# ensure patch_size has the right length
patch_size_ = get_valid_patch_size(image_size, patch_size)

# collect the ranges to step over each dimension
ranges = tuple(starmap(range, zip(start_pos, dims, patch_size_)))

# choose patches by applying product to the ranges
for position in product(*ranges):
# create slices based on start position of each patch
for position in iter_patch_position(
image_size=image_size, patch_size=patch_size_, start_pos=start_pos, overlap=overlap, padded=padded
):
yield tuple(slice(s, s + p) for s, p in zip(position, patch_size_))


Expand Down Expand Up @@ -192,10 +197,54 @@ def dense_patch_slices(
return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]


def iter_patch_position(
image_size: Sequence[int],
patch_size: Union[Sequence[int], int],
start_pos: Sequence[int] = (),
overlap: Union[Sequence[float], float] = 0.0,
padded: bool = False,
):
"""
Yield successive tuples of upper left corner of patches of size `patch_size` from an array of dimensions `image_size`.
The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each
patch is chosen in a contiguous grid using a rwo-major ordering.

Args:
image_size: dimensions of array to iterate over
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
padded: if the image is padded so the patches can go beyond the borders. Defaults to False.

Yields:
Tuples of positions defining the upper left corner of each patch
"""

# ensure patchSize and startPos are the right length
ndim = len(image_size)
patch_size_ = get_valid_patch_size(image_size, patch_size)
start_pos = ensure_tuple_size(start_pos, ndim)
overlap = ensure_tuple_rep(overlap, ndim)

# calculate steps, which depends on the amount of overlap
steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size_, overlap))

# calculate the last starting location (depending on the padding)
end_pos = image_size if padded else tuple(s - round(p) + 1 for s, p in zip(image_size, patch_size_))

# collect the ranges to step over each dimension
ranges = starmap(range, zip(start_pos, end_pos, steps))

# choose patches by applying product to the ranges
return product(*ranges)


def iter_patch(
arr: np.ndarray,
patch_size: Union[Sequence[int], int] = 0,
start_pos: Sequence[int] = (),
overlap: Union[Sequence[float], float] = 0.0,
copy_back: bool = True,
mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
**pad_opts: Dict,
Expand All @@ -209,6 +258,8 @@ def iter_patch(
arr: array to iterate over
patch_size: size of patches to generate slices for, 0 or None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
Expand Down Expand Up @@ -243,7 +294,7 @@ def iter_patch(
# patches which are only in the padded regions
iter_size = tuple(s + p for s, p in zip(arr.shape, patch_size_))

for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded):
for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded, overlap):
# compensate original image padding
coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, patch_size_))
yield arrpad[slices], np.asarray(coords_no_pad) # data and coords (in numpy; works with torch loader)
Expand Down
132 changes: 130 additions & 2 deletions monai/data/wsi_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import numpy as np

from monai.data import Dataset
from monai.data.utils import iter_patch_position
from monai.data.wsi_reader import BaseWSIReader, WSIReader
from monai.transforms import apply_transform
from monai.transforms import Randomizable, apply_transform
from monai.utils import ensure_tuple_rep

__all__ = ["PatchWSIDataset"]
__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset"]


class PatchWSIDataset(Dataset):
Expand Down Expand Up @@ -137,3 +138,130 @@ def _transform(self, index: int):
# Apply transforms and output
output = {"image": image, "label": label, "metadata": metadata}
return apply_transform(self.transform, output) if self.transform else output


class SlidingPatchWSIDataset(Randomizable, PatchWSIDataset):
"""
This dataset extracts patches from whole slide images (without loading the whole image)
It also reads labels for each patch and provides each patch with its associated class labels.

Args:
data: the list of input samples including image, location, and label (see the note below for more details).
size: the size of patch to be extracted from the whole slide image.
level: the level at which the patches to be extracted (default to 0).
offset: the offset of image to extract patches (the starting position of the upper left patch).
offset_limits: if offset is set to "random", a tuple of integers defining the lower and upper limit of the
random offset for all dimensions, or a tuple of tuples that defines the limits for each dimension.
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
transform: transforms to be executed on input data.
reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is

- a string, it defines the backend of `monai.data.WSIReader`.
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,
- an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

seed: random seed to randomly generate offsets. Defaults to 0.
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

Note:
The input data has the following form as an example:

.. code-block:: python

[
{"image": "path/to/image1.tiff"},
{"image": "path/to/image2.tiff", "size": [20, 20], "level": 2}
]

"""

def __init__(
self,
data: Sequence,
size: Optional[Union[int, Tuple[int, int]]] = None,
level: Optional[int] = None,
overlap: Union[Tuple[float, float], float] = 0.0,
offset: Union[Tuple[int, int], int, str] = (0, 0),
offset_limits: Optional[Union[Tuple[Tuple[int, int], Tuple[int, int]], Tuple[int, int]]] = None,
transform: Optional[Callable] = None,
reader="cuCIM",
seed: int = 0,
**kwargs,
):
super().__init__(data=data, size=size, level=level, transform=transform, reader=reader, **kwargs)
self.overlap = overlap
self.set_random_state(seed)
# Set the offset config
self.random_offset = False
if isinstance(offset, str):
if offset == "random":
self.random_offset = True
self.offset_limits: Optional[Tuple[Tuple[int, int], Tuple[int, int]]]
if offset_limits is None:
self.offset_limits = None
elif isinstance(offset_limits, tuple):
if isinstance(offset_limits[0], int):
self.offset_limits = (offset_limits, offset_limits)
elif isinstance(offset_limits[0], tuple):
self.offset_limits = offset_limits
else:
ValueError(
"The offset limits should be either a tuple of integers or tuple of tuple of integers."
)
else:
ValueError("The offset limits should be a tuple.")
else:
ValueError(
f'Invalid string for offset "{offset}". It should be either "random" as a string,'
"an integer, or a tuple of integers defining the offset."
)
else:
self.offset = ensure_tuple_rep(offset, 2)

# Create single sample for each patch (in a sliding window manner)
self.data = []
for sample in data:
sliding_samples = self._evaluate_patch_coordinates(sample)
self.data.extend(sliding_samples)

def _get_offset(self, sample):
if self.random_offset:
if self.offset_limits is None:
offset_limits = tuple((-s, s) for s in self._get_size(sample))
else:
offset_limits = self.offset_limits
return tuple(self.R.randint(low, high) for low, high in offset_limits)
return self.offset

def _evaluate_patch_coordinates(self, sample):
"""Define the location for each patch based on sliding-window approach"""
patch_size = self._get_size(sample)
level = self._get_level(sample)
start_pos = self._get_offset(sample)

wsi_obj = self._get_wsi_object(sample)
wsi_size = self.wsi_reader.get_size(wsi_obj, 0)
downsample = self.wsi_reader.get_downsample_ratio(wsi_obj, level)
patch_size_ = tuple(p * downsample for p in patch_size) # patch size at level 0
locations = list(
iter_patch_position(
image_size=wsi_size, patch_size=patch_size_, start_pos=start_pos, overlap=self.overlap, padded=False
)
)
sample["size"] = patch_size
sample["level"] = level
n_patches = len(locations)
return [{**sample, "location": loc, "num_patches": n_patches} for loc in locations]

def _get_location(self, sample: Dict):
return sample["location"]

def _transform(self, index: int):
# Get a single entry of data
sample: Dict = self.data[index]
# Extract patch image and associated metadata
image, metadata = self._get_data(sample)
# Create put all patch information together and apply transforms
patch = {"image": image, "metadata": metadata}
return apply_transform(self.transform, patch) if self.transform else patch
47 changes: 47 additions & 0 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def get_level_count(self, wsi) -> int:
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_downsample_ratio(self, wsi, level: int) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
level: the level number where the size is calculated

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

@abstractmethod
def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
Expand Down Expand Up @@ -290,6 +302,17 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]:
"""
return self.reader.get_size(wsi, level)

def get_downsample_ratio(self, wsi, level: int) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
level: the level number where the size is calculated

"""
return self.reader.get_downsample_ratio(wsi, level)

def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
return self.reader.get_file_path(wsi)
Expand Down Expand Up @@ -369,6 +392,18 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
"""
return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0])

@staticmethod
def get_downsample_ratio(wsi, level: int) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
level: the level number where the size is calculated

"""
return wsi.resolutions["level_downsamples"][level] # type: ignore

def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
return str(abspath(wsi.path))
Expand Down Expand Up @@ -475,6 +510,18 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
"""
return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0])

@staticmethod
def get_downsample_ratio(wsi, level: int) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.

Args:
wsi: a whole slide image object loaded from a file
level: the level number where the size is calculated

"""
return wsi.level_downsamples[level] # type: ignore

def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
return str(abspath(wsi._filename))
Expand Down
Loading