From 6eb6eeb1e788aa1eb1defd785f081e5e34bdbeae Mon Sep 17 00:00:00 2001 From: myron Date: Mon, 1 Nov 2021 17:53:58 -0700 Subject: [PATCH 1/5] MIL component to extract patches Signed-off-by: myron --- monai/apps/pathology/transforms/__init__.py | 4 +- .../pathology/transforms/spatial/__init__.py | 4 +- .../pathology/transforms/spatial/array.py | 164 +++++++++++++++++- .../transforms/spatial/dictionary.py | 84 ++++++++- tests/test_tile_on_grid.py | 110 ++++++++++++ tests/test_tile_on_grid_dict.py | 123 +++++++++++++ 6 files changed, 478 insertions(+), 11 deletions(-) create mode 100644 tests/test_tile_on_grid.py create mode 100644 tests/test_tile_on_grid_dict.py diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 1be96b8e34..b418e20279 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .spatial.array import SplitOnGrid -from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .spatial.array import SplitOnGrid, TileOnGrid +from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict from .stain.array import ExtractHEStains, NormalizeHEStains from .stain.dictionary import ( ExtractHEStainsd, diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py index 07ba222ab0..c9971254e7 100644 --- a/monai/apps/pathology/transforms/spatial/__init__.py +++ b/monai/apps/pathology/transforms/spatial/__init__.py @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .array import SplitOnGrid -from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .array import SplitOnGrid, TileOnGrid +from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 4edf987610..94ab1e2da5 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Sequence, Tuple, Union +import numpy as np import torch +from numpy.lib.stride_tricks import as_strided -from monai.transforms.transform import Transform +from monai.transforms.transform import Randomizable, Transform -__all__ = ["SplitOnGrid"] +__all__ = ["SplitOnGrid", "TileOnGrid"] class SplitOnGrid(Transform): @@ -73,3 +75,159 @@ def get_params(self, image_size): ) return patch_size, steps + + +class TileOnGrid(Randomizable, Transform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None Extract all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to None (same as tile_size) + random_offset: Randomize position of tile grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more then tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset + Defaults to ``min`` (which assumes background is white, high value) + + """ + + def __init__( + self, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: Optional[str] = "min", + ): + self.tile_count = tile_count + self.tile_size = tile_size + self.step = step + self.random_offset = random_offset + self.pad_full = pad_full + self.background_val = background_val + self.filter_mode = filter_mode + + # self.tile_all = (self.tile_count is None) + + # if self.tile_count is None: + # self.tile_count = max((44 * 256 ** 2) // (tile_size ** 2), 1) + + if self.step is None: + self.step = self.tile_size # non-overlapping grid + + self.offset = (0, 0) + self.random_idxs = [0] + + def randomize(self, img_size: Sequence[int]) -> None: + + c, h, w = img_size + # tile_count: int = self.tile_count # type: ignore + tile_step: int = self.step # type: ignore + + if self.random_offset: + pad_h = h % self.tile_size + pad_w = w % self.tile_size + if pad_h > 0 and pad_w > 0: + self.offset = (self.R.randint(pad_h), self.R.randint(pad_w)) + h = h - self.offset[0] + w = w - self.offset[1] + else: + self.offset = (0, 0) + + if self.pad_full: + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + h = h + pad_h + w = w + pad_w + + h_n = (h - self.tile_size + tile_step) // tile_step + w_n = (w - self.tile_size + tile_step) // tile_step + tile_total = h_n * w_n + + if self.tile_count is not None and tile_total > self.tile_count: + self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False) # type: ignore + else: + self.random_idxs = [0] # type: ignore + + def __call__(self, image: np.ndarray) -> np.ndarray: + + # add random offset + self.randomize(img_size=image.shape) + # tile_count: int = self.tile_count # type: ignore + tile_step: int = self.step # type: ignore + + if self.random_offset and self.offset is not None: + image = image[:, self.offset[0] :, self.offset[1] :] + + # pad to full size, divisible by tile_size + if self.pad_full: + c, h, w = image.shape + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + image = np.pad( + image, + [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], + constant_values=self.background_val, + ) + + # extact tiles (new way) + xstep, ystep = tile_step, tile_step + xsize, ysize = self.tile_size, self.tile_size + clen, xlen, ylen = image.shape + cstride, xstride, ystride = image.strides + llw = as_strided( + image, + shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize), + strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride), + writeable=False, + ) + image = llw.reshape(-1, clen, xsize, ysize) + + # if keep all patches + if self.tile_count is None: + # retain only patches with significant foreground content to speed up inference + # FYI, this returns a variable number of tiles, so the batch_size much be 1 (per gpu). Used during inference + thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size + if self.filter_mode == "min": + # default, keep non-background tiles (small values) + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh) + image = image[idxs.reshape(-1)] + elif self.filter_mode == "max": + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh) + image = image[idxs.reshape(-1)] + + else: + if len(image) >= self.tile_count: + + if self.filter_mode == "min": + # default, keep non-background tiles (smallest values) + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count] + image = image[idxs] + elif self.filter_mode == "max": + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :] + image = image[idxs] + elif len(image) > self.tile_count: + # random subset (more appropriate for WSIs without distinct background) + if self.random_idxs is not None: + image = image[self.random_idxs] + + else: + image = np.pad( + image, + [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]], + constant_values=self.background_val, + ) + + return image diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 10b01a39de..12d4e973e8 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -9,16 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Mapping, Optional, Tuple, Union +import copy +from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union +import numpy as np import torch from monai.config import KeysCollection -from monai.transforms.transform import MapTransform +from monai.transforms.transform import MapTransform, Randomizable -from .array import SplitOnGrid +from .array import SplitOnGrid, TileOnGrid -__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"] +__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"] class SplitOnGridd(MapTransform): @@ -53,4 +55,78 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class TileOnGridd(Randomizable, MapTransform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None Extract all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to None (same as tile_size) + random_offset: Randomize position of tile grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more then tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset + Defaults to ``min`` (which assumes background is white, high value) + + """ + + def __init__( + self, + keys: KeysCollection, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: Optional[str] = "min", + allow_missing_keys: bool = False, + return_list_of_dicts: bool = False, + ): + super().__init__(keys, allow_missing_keys) + + self.return_list_of_dicts = return_list_of_dicts + self.seed = None + + self.splitter = TileOnGrid( + tile_count=tile_count, + tile_size=tile_size, + step=step, + random_offset=random_offset, + pad_full=pad_full, + background_val=background_val, + filter_mode=filter_mode, + ) + + def randomize(self, data: Any = None) -> None: + self.seed = self.R.randint(10000) # type: ignore + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]: + + self.randomize() + + d = dict(data) + for key in self.key_iterator(d): + self.splitter.set_random_state(seed=self.seed) # same random seed for all keys + d[key] = self.splitter(d[key]) + + if self.return_list_of_dicts: + d_list = [] + for i in range(len(d[self.keys[0]])): + d_list.append({k: d[k][i] if k in self.keys else copy.deepcopy(d[k]) for k in d.keys()}) + d = d_list # type: ignore + + return d + + SplitOnGridDict = SplitOnGridD = SplitOnGridd +TileOnGridDict = TileOnGridD = TileOnGridd diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py new file mode 100644 index 0000000000..a35bf4db68 --- /dev/null +++ b/tests/test_tile_on_grid.py @@ -0,0 +1,110 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGrid + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + } + ] + ) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + } + ] + ) + + +def make_image( + tile_count: int, tile_size: int, random_offset: bool = False, filter_mode: Optional[str] = None, seed=123, **kwargs +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + image = np.random.randint(200, size=[3, tile_count * tile_size + pad, tile_count * tile_size + pad], dtype=np.uint8) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + image = image[ + :, random_state.randint(image.shape[1] % tile_size) :, random_state.randint(image.shape[2] % tile_size) : + ] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * tile_size : (x + 1) * tile_size, y * tile_size : (y + 1) * tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if filter_mode == "min" or filter_mode == "max": + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGrid(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_tile_pathce_single_call(self, input_parameters): + + img, tiles = make_image(**input_parameters) + + tiler = TileOnGrid(**input_parameters) + output = tiler(img) + np.testing.assert_equal(output, tiles) + + @parameterized.expand(TEST_CASES2) + def test_tile_pathce_random_call(self, input_parameters): + + img, tiles = make_image(**input_parameters, seed=123) + + tiler = TileOnGrid(**input_parameters) + tiler.set_random_state(seed=123) + + output = tiler(img) + np.testing.assert_equal(output, tiles) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py new file mode 100644 index 0000000000..d660193636 --- /dev/null +++ b/tests/test_tile_on_grid_dict.py @@ -0,0 +1,123 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGridDict + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + } + ] + ) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + } + ] + ) + + +def make_image( + tile_count: int, tile_size: int, random_offset: bool = False, filter_mode: Optional[str] = None, seed=123, **kwargs +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + image = np.random.randint(200, size=[3, tile_count * tile_size + pad, tile_count * tile_size + pad], dtype=np.uint8) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + image = image[ + :, random_state.randint(image.shape[1] % tile_size) :, random_state.randint(image.shape[2] % tile_size) : + ] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * tile_size : (x + 1) * tile_size, y * tile_size : (y + 1) * tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if filter_mode == "min" or filter_mode == "max": + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGridDict(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_tile_pathce_single_call(self, input_parameters): + + key = "image" + input_parameters["keys"] = key + + img, tiles = make_image(**input_parameters) + + splitter = TileOnGridDict(**input_parameters) + + output = splitter({key: img}) + output = output[key] + + np.testing.assert_equal(tiles, output) + + @parameterized.expand(TEST_CASES2) + def test_tile_pathce_random_call(self, input_parameters): + + key = "image" + input_parameters["keys"] = key + + random_state = np.random.RandomState(123) + seed = random_state.randint(10000) + img, tiles = make_image(**input_parameters, seed=seed) + + splitter = TileOnGridDict(**input_parameters) + splitter.set_random_state(seed=123) + + output = splitter({key: img}) + output = output[key] + + np.testing.assert_equal(tiles, output) + + +if __name__ == "__main__": + unittest.main() From 204e95aae8a4c2f5393b4a2f0fbf569b8172be9c Mon Sep 17 00:00:00 2001 From: myron Date: Mon, 1 Nov 2021 17:53:58 -0700 Subject: [PATCH 2/5] MIL component to extract patches Signed-off-by: myron --- monai/apps/pathology/transforms/__init__.py | 4 +- .../pathology/transforms/spatial/__init__.py | 4 +- .../pathology/transforms/spatial/array.py | 157 +++++++++++++++++- .../transforms/spatial/dictionary.py | 84 +++++++++- tests/test_tile_on_grid.py | 110 ++++++++++++ tests/test_tile_on_grid_dict.py | 123 ++++++++++++++ 6 files changed, 471 insertions(+), 11 deletions(-) create mode 100644 tests/test_tile_on_grid.py create mode 100644 tests/test_tile_on_grid_dict.py diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 1be96b8e34..b418e20279 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .spatial.array import SplitOnGrid -from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .spatial.array import SplitOnGrid, TileOnGrid +from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict from .stain.array import ExtractHEStains, NormalizeHEStains from .stain.dictionary import ( ExtractHEStainsd, diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py index 07ba222ab0..c9971254e7 100644 --- a/monai/apps/pathology/transforms/spatial/__init__.py +++ b/monai/apps/pathology/transforms/spatial/__init__.py @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .array import SplitOnGrid -from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .array import SplitOnGrid, TileOnGrid +from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 4edf987610..36e2fc2ee2 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Sequence, Tuple, Union +import numpy as np import torch +from numpy.lib.stride_tricks import as_strided -from monai.transforms.transform import Transform +from monai.transforms.transform import Randomizable, Transform -__all__ = ["SplitOnGrid"] +__all__ = ["SplitOnGrid", "TileOnGrid"] class SplitOnGrid(Transform): @@ -73,3 +75,152 @@ def get_params(self, image_size): ) return patch_size, steps + + +class TileOnGrid(Randomizable, Transform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to None (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more then tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + def __init__( + self, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: Optional[str] = "min", + ): + self.tile_count = tile_count + self.tile_size = tile_size + self.step = step + self.random_offset = random_offset + self.pad_full = pad_full + self.background_val = background_val + self.filter_mode = filter_mode + + if self.step is None: + self.step = self.tile_size # non-overlapping grid + + self.offset = (0, 0) + self.random_idxs = [0] + + def randomize(self, img_size: Sequence[int]) -> None: + + c, h, w = img_size + tile_step: int = self.step # type: ignore + + if self.random_offset: + pad_h = h % self.tile_size + pad_w = w % self.tile_size + if pad_h > 0 and pad_w > 0: + self.offset = (self.R.randint(pad_h), self.R.randint(pad_w)) + h = h - self.offset[0] + w = w - self.offset[1] + else: + self.offset = (0, 0) + + if self.pad_full: + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + h = h + pad_h + w = w + pad_w + + h_n = (h - self.tile_size + tile_step) // tile_step + w_n = (w - self.tile_size + tile_step) // tile_step + tile_total = h_n * w_n + + if self.tile_count is not None and tile_total > self.tile_count: + self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False) # type: ignore + else: + self.random_idxs = [0] # type: ignore + + def __call__(self, image: np.ndarray) -> np.ndarray: + + # add random offset + self.randomize(img_size=image.shape) + tile_step: int = self.step # type: ignore + + if self.random_offset and self.offset is not None: + image = image[:, self.offset[0] :, self.offset[1] :] + + # pad to full size, divisible by tile_size + if self.pad_full: + c, h, w = image.shape + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + image = np.pad( + image, + [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], + constant_values=self.background_val, + ) + + # extact tiles + xstep, ystep = tile_step, tile_step + xsize, ysize = self.tile_size, self.tile_size + clen, xlen, ylen = image.shape + cstride, xstride, ystride = image.strides + llw = as_strided( + image, + shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize), + strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride), + writeable=False, + ) + image = llw.reshape(-1, clen, xsize, ysize) + + # if keeping all patches + if self.tile_count is None: + # retain only patches with significant foreground content to speed up inference + # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference + thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size + if self.filter_mode == "min": + # default, keep non-background tiles (small values) + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh) + image = image[idxs.reshape(-1)] + elif self.filter_mode == "max": + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh) + image = image[idxs.reshape(-1)] + + else: + if len(image) >= self.tile_count: + + if self.filter_mode == "min": + # default, keep non-background tiles (smallest values) + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count] + image = image[idxs] + elif self.filter_mode == "max": + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :] + image = image[idxs] + elif len(image) > self.tile_count: + # random subset (more appropriate for WSIs without distinct background) + if self.random_idxs is not None: + image = image[self.random_idxs] + + else: + image = np.pad( + image, + [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]], + constant_values=self.background_val, + ) + + return image diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 10b01a39de..28d72b8147 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -9,16 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Mapping, Optional, Tuple, Union +import copy +from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union +import numpy as np import torch from monai.config import KeysCollection -from monai.transforms.transform import MapTransform +from monai.transforms.transform import MapTransform, Randomizable -from .array import SplitOnGrid +from .array import SplitOnGrid, TileOnGrid -__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"] +__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"] class SplitOnGridd(MapTransform): @@ -53,4 +55,78 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class TileOnGridd(Randomizable, MapTransform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to None (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more then tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + def __init__( + self, + keys: KeysCollection, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: Optional[str] = "min", + allow_missing_keys: bool = False, + return_list_of_dicts: bool = False, + ): + super().__init__(keys, allow_missing_keys) + + self.return_list_of_dicts = return_list_of_dicts + self.seed = None + + self.splitter = TileOnGrid( + tile_count=tile_count, + tile_size=tile_size, + step=step, + random_offset=random_offset, + pad_full=pad_full, + background_val=background_val, + filter_mode=filter_mode, + ) + + def randomize(self, data: Any = None) -> None: + self.seed = self.R.randint(10000) # type: ignore + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]: + + self.randomize() + + d = dict(data) + for key in self.key_iterator(d): + self.splitter.set_random_state(seed=self.seed) # same random seed for all keys + d[key] = self.splitter(d[key]) + + if self.return_list_of_dicts: + d_list = [] + for i in range(len(d[self.keys[0]])): + d_list.append({k: d[k][i] if k in self.keys else copy.deepcopy(d[k]) for k in d.keys()}) + d = d_list # type: ignore + + return d + + SplitOnGridDict = SplitOnGridD = SplitOnGridd +TileOnGridDict = TileOnGridD = TileOnGridd diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py new file mode 100644 index 0000000000..a35bf4db68 --- /dev/null +++ b/tests/test_tile_on_grid.py @@ -0,0 +1,110 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGrid + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + } + ] + ) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + } + ] + ) + + +def make_image( + tile_count: int, tile_size: int, random_offset: bool = False, filter_mode: Optional[str] = None, seed=123, **kwargs +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + image = np.random.randint(200, size=[3, tile_count * tile_size + pad, tile_count * tile_size + pad], dtype=np.uint8) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + image = image[ + :, random_state.randint(image.shape[1] % tile_size) :, random_state.randint(image.shape[2] % tile_size) : + ] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * tile_size : (x + 1) * tile_size, y * tile_size : (y + 1) * tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if filter_mode == "min" or filter_mode == "max": + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGrid(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_tile_pathce_single_call(self, input_parameters): + + img, tiles = make_image(**input_parameters) + + tiler = TileOnGrid(**input_parameters) + output = tiler(img) + np.testing.assert_equal(output, tiles) + + @parameterized.expand(TEST_CASES2) + def test_tile_pathce_random_call(self, input_parameters): + + img, tiles = make_image(**input_parameters, seed=123) + + tiler = TileOnGrid(**input_parameters) + tiler.set_random_state(seed=123) + + output = tiler(img) + np.testing.assert_equal(output, tiles) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py new file mode 100644 index 0000000000..d660193636 --- /dev/null +++ b/tests/test_tile_on_grid_dict.py @@ -0,0 +1,123 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGridDict + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + } + ] + ) + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in [None, "min", "max"]: + for background_val in [255, 0]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + } + ] + ) + + +def make_image( + tile_count: int, tile_size: int, random_offset: bool = False, filter_mode: Optional[str] = None, seed=123, **kwargs +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + image = np.random.randint(200, size=[3, tile_count * tile_size + pad, tile_count * tile_size + pad], dtype=np.uint8) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + image = image[ + :, random_state.randint(image.shape[1] % tile_size) :, random_state.randint(image.shape[2] % tile_size) : + ] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * tile_size : (x + 1) * tile_size, y * tile_size : (y + 1) * tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if filter_mode == "min" or filter_mode == "max": + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGridDict(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_tile_pathce_single_call(self, input_parameters): + + key = "image" + input_parameters["keys"] = key + + img, tiles = make_image(**input_parameters) + + splitter = TileOnGridDict(**input_parameters) + + output = splitter({key: img}) + output = output[key] + + np.testing.assert_equal(tiles, output) + + @parameterized.expand(TEST_CASES2) + def test_tile_pathce_random_call(self, input_parameters): + + key = "image" + input_parameters["keys"] = key + + random_state = np.random.RandomState(123) + seed = random_state.randint(10000) + img, tiles = make_image(**input_parameters, seed=seed) + + splitter = TileOnGridDict(**input_parameters) + splitter.set_random_state(seed=123) + + output = splitter({key: img}) + output = output[key] + + np.testing.assert_equal(tiles, output) + + +if __name__ == "__main__": + unittest.main() From a680d02d60b44516b147052495877ce5ce862989 Mon Sep 17 00:00:00 2001 From: myron Date: Fri, 12 Nov 2021 14:31:14 -0800 Subject: [PATCH 3/5] random flag, minor fixes Signed-off-by: myron --- monai/apps/pathology/transforms/spatial/array.py | 9 ++++++--- monai/apps/pathology/transforms/spatial/dictionary.py | 6 +++--- tests/test_tile_on_grid.py | 4 ++-- tests/test_tile_on_grid_dict.py | 4 ++-- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 4692b80c3a..1ad35f8ab1 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -95,8 +95,8 @@ class TileOnGrid(Randomizable, Transform): Defaults to ``False``. background_val: the background constant (e.g. 255 for white background) Defaults to ``255``. - filter_mode: mode must be in ["min", "max", "random", None]. If total number of tiles is more than tile_size, - then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random or None) subset + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset Defaults to ``min`` (which assumes background is high value) """ @@ -109,7 +109,7 @@ def __init__( random_offset: bool = False, pad_full: bool = False, background_val: int = 255, - filter_mode: Optional[str] = "min", + filter_mode: str = "min", ): self.tile_count = tile_count self.tile_size = tile_size @@ -125,6 +125,9 @@ def __init__( self.offset = (0, 0) self.random_idxs = np.array((0,)) + if self.filter_mode not in ["min", "max", "random"]: + raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode)) + def randomize(self, img_size: Sequence[int]) -> None: c, h, w = img_size diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index c9fd97d864..0168ac3108 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -73,8 +73,8 @@ class TileOnGridd(Randomizable, MapTransform): Defaults to ``False``. background_val: the background constant (e.g. 255 for white background) Defaults to ``255``. - filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more than tile_size, - then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset Defaults to ``min`` (which assumes background is high value) """ @@ -88,7 +88,7 @@ def __init__( random_offset: bool = False, pad_full: bool = False, background_val: int = 255, - filter_mode: Optional[str] = "min", + filter_mode: str = "min", allow_missing_keys: bool = False, return_list_of_dicts: bool = False, ): diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index a35bf4db68..7a81cac5f4 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -20,7 +20,7 @@ TEST_CASES = [] for tile_count in [16, 64]: for tile_size in [8, 32]: - for filter_mode in [None, "min", "max"]: + for filter_mode in ["min", "max", "random"]: for background_val in [255, 0]: TEST_CASES.append( [ @@ -37,7 +37,7 @@ TEST_CASES2 = [] for tile_count in [16, 64]: for tile_size in [8, 32]: - for filter_mode in [None, "min", "max"]: + for filter_mode in ["min", "max", "random"]: for background_val in [255, 0]: TEST_CASES2.append( [ diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index d660193636..f2b28714ba 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -20,7 +20,7 @@ TEST_CASES = [] for tile_count in [16, 64]: for tile_size in [8, 32]: - for filter_mode in [None, "min", "max"]: + for filter_mode in ["min", "max", "random"]: for background_val in [255, 0]: TEST_CASES.append( [ @@ -37,7 +37,7 @@ TEST_CASES2 = [] for tile_count in [16, 64]: for tile_size in [8, 32]: - for filter_mode in [None, "min", "max"]: + for filter_mode in ["min", "max", "random"]: for background_val in [255, 0]: TEST_CASES2.append( [ From 715bc98c6f4366c54a75d7c3460a3c81c1b5509b Mon Sep 17 00:00:00 2001 From: myron Date: Sat, 13 Nov 2021 17:03:43 -0800 Subject: [PATCH 4/5] minor fixes for padding Signed-off-by: myron --- .../pathology/transforms/spatial/array.py | 18 +++++------ tests/test_tile_on_grid.py | 30 ++++++++++--------- tests/test_tile_on_grid_dict.py | 9 +++--- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 1ad35f8ab1..e08ac7f46f 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -133,15 +133,13 @@ def randomize(self, img_size: Sequence[int]) -> None: c, h, w = img_size tile_step = cast(int, self.step) + self.offset = (0, 0) if self.random_offset: pad_h = h % self.tile_size pad_w = w % self.tile_size - if pad_h > 0 and pad_w > 0: - self.offset = (self.R.randint(pad_h), self.R.randint(pad_w)) - h = h - self.offset[0] - w = w - self.offset[1] - else: - self.offset = (0, 0) + self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0) + h = h - self.offset[0] + w = w - self.offset[1] if self.pad_full: pad_h = (self.tile_size - h % self.tile_size) % self.tile_size @@ -164,7 +162,7 @@ def __call__(self, image: np.ndarray) -> np.ndarray: self.randomize(img_size=image.shape) tile_step = cast(int, self.step) - if self.random_offset and self.offset[0] > 0 and self.offset[1] > 0: + if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0): image = image[:, self.offset[0] :, self.offset[1] :] # pad to full size, divisible by tile_size @@ -205,7 +203,7 @@ def __call__(self, image: np.ndarray) -> np.ndarray: image = image[idxs.reshape(-1)] else: - if len(image) >= self.tile_count: + if len(image) > self.tile_count: if self.filter_mode == "min": # default, keep non-background tiles (smallest values) @@ -214,12 +212,12 @@ def __call__(self, image: np.ndarray) -> np.ndarray: elif self.filter_mode == "max": idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :] image = image[idxs] - elif len(image) > self.tile_count: + else: # random subset (more appropriate for WSIs without distinct background) if self.random_idxs is not None: image = image[self.random_idxs] - else: + elif len(image) < self.tile_count: image = np.pad( image, [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]], diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index 7a81cac5f4..e5d8a733ea 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -34,6 +34,7 @@ ] ) + TEST_CASES2 = [] for tile_count in [16, 64]: for tile_size in [8, 32]: @@ -67,9 +68,10 @@ def make_image( random_state = np.random.RandomState(seed) if random_offset: - image = image[ - :, random_state.randint(image.shape[1] % tile_size) :, random_state.randint(image.shape[2] % tile_size) : - ] + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] tiles_list = [] for x in range(tile_count): @@ -78,7 +80,7 @@ def make_image( tiles = np.stack(tiles_list, axis=0) # type: ignore - if filter_mode == "min" or filter_mode == "max": + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] return imlarge, tiles @@ -94,16 +96,16 @@ def test_tile_pathce_single_call(self, input_parameters): output = tiler(img) np.testing.assert_equal(output, tiles) - @parameterized.expand(TEST_CASES2) - def test_tile_pathce_random_call(self, input_parameters): - - img, tiles = make_image(**input_parameters, seed=123) - - tiler = TileOnGrid(**input_parameters) - tiler.set_random_state(seed=123) - - output = tiler(img) - np.testing.assert_equal(output, tiles) + # @parameterized.expand(TEST_CASES2) + # def test_tile_pathce_random_call(self, input_parameters): + # + # img, tiles = make_image(**input_parameters, seed=123) + # + # tiler = TileOnGrid(**input_parameters) + # tiler.set_random_state(seed=123) + # + # output = tiler(img) + # np.testing.assert_equal(output, tiles) if __name__ == "__main__": diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index f2b28714ba..8428948adc 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -67,9 +67,10 @@ def make_image( random_state = np.random.RandomState(seed) if random_offset: - image = image[ - :, random_state.randint(image.shape[1] % tile_size) :, random_state.randint(image.shape[2] % tile_size) : - ] + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] tiles_list = [] for x in range(tile_count): @@ -78,7 +79,7 @@ def make_image( tiles = np.stack(tiles_list, axis=0) # type: ignore - if filter_mode == "min" or filter_mode == "max": + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] return imlarge, tiles From 3586f6528acdbeb8f5efb6a075f04450f1771c08 Mon Sep 17 00:00:00 2001 From: myron Date: Sun, 14 Nov 2021 22:00:16 -0800 Subject: [PATCH 5/5] improve tests Signed-off-by: myron --- tests/test_tile_on_grid.py | 22 +++++------ tests/test_tile_on_grid_dict.py | 65 ++++++++++++++++++++------------- 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index e5d8a733ea..f8c86fa90a 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -88,7 +88,7 @@ def make_image( class TestTileOnGrid(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_tile_pathce_single_call(self, input_parameters): + def test_tile_patch_single_call(self, input_parameters): img, tiles = make_image(**input_parameters) @@ -96,16 +96,16 @@ def test_tile_pathce_single_call(self, input_parameters): output = tiler(img) np.testing.assert_equal(output, tiles) - # @parameterized.expand(TEST_CASES2) - # def test_tile_pathce_random_call(self, input_parameters): - # - # img, tiles = make_image(**input_parameters, seed=123) - # - # tiler = TileOnGrid(**input_parameters) - # tiler.set_random_state(seed=123) - # - # output = tiler(img) - # np.testing.assert_equal(output, tiles) + @parameterized.expand(TEST_CASES2) + def test_tile_patch_random_call(self, input_parameters): + + img, tiles = make_image(**input_parameters, seed=123) + + tiler = TileOnGrid(**input_parameters) + tiler.set_random_state(seed=123) + + output = tiler(img) + np.testing.assert_equal(output, tiles) if __name__ == "__main__": diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index 8428948adc..95cfa179dd 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -22,34 +22,39 @@ for tile_size in [8, 32]: for filter_mode in ["min", "max", "random"]: for background_val in [255, 0]: - TEST_CASES.append( - [ - { - "tile_count": tile_count, - "tile_size": tile_size, - "filter_mode": filter_mode, - "random_offset": False, - "background_val": background_val, - } - ] - ) + for return_list_of_dicts in [False, True]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) + TEST_CASES2 = [] for tile_count in [16, 64]: for tile_size in [8, 32]: for filter_mode in ["min", "max", "random"]: for background_val in [255, 0]: - TEST_CASES2.append( - [ - { - "tile_count": tile_count, - "tile_size": tile_size, - "filter_mode": filter_mode, - "random_offset": True, - "background_val": background_val, - } - ] - ) + for return_list_of_dicts in [False, True]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) def make_image( @@ -87,7 +92,7 @@ def make_image( class TestTileOnGridDict(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_tile_pathce_single_call(self, input_parameters): + def test_tile_patch_single_call(self, input_parameters): key = "image" input_parameters["keys"] = key @@ -97,12 +102,16 @@ def test_tile_pathce_single_call(self, input_parameters): splitter = TileOnGridDict(**input_parameters) output = splitter({key: img}) - output = output[key] + + if input_parameters.get("return_list_of_dicts", False): + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] np.testing.assert_equal(tiles, output) @parameterized.expand(TEST_CASES2) - def test_tile_pathce_random_call(self, input_parameters): + def test_tile_patch_random_call(self, input_parameters): key = "image" input_parameters["keys"] = key @@ -115,7 +124,11 @@ def test_tile_pathce_random_call(self, input_parameters): splitter.set_random_state(seed=123) output = splitter({key: img}) - output = output[key] + + if input_parameters.get("return_list_of_dicts", False): + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] np.testing.assert_equal(tiles, output)