diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index f02fb141e4..65fceebe8d 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -105,6 +105,12 @@ Crop and Pad :members: :special-members: __call__ +`BoundingRect` +"""""""""""""" +.. autoclass:: BoundingRect + :members: + :special-members: __call__ + Intensity ^^^^^^^^^ @@ -564,6 +570,12 @@ Crop and Pad (Dict) :members: :special-members: __call__ +`BoundingRectd` +""""""""""""""" +.. autoclass:: BoundingRectd + :members: + :special-members: __call__ + Instensity (Dict) ^^^^^^^^^^^^^^^^^ diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index d8075e5d01..2cd8ee861e 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -21,6 +21,7 @@ from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import ( + compute_bounding_rect, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, map_binary_to_indices, @@ -632,3 +633,16 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ return self.padder(self.cropper(img), mode=mode) + + +class BoundingRect(Transform): + """ + Compute coordinates of axis-aligned bounding rectangles from input image `img`. + """ + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + See also: :py:class:`monai.transforms.utils.compute_bounding_rect`. + """ + bbox = [compute_bounding_rect(channel) for channel in img] + return np.stack(bbox, axis=0) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 7d0b4f85cd..557e80474b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -24,6 +24,7 @@ from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.croppad.array import ( BorderPad, + BoundingRect, CenterSpatialCrop, DivisiblePad, ResizeWithPadOrCrop, @@ -580,6 +581,36 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class BoundingRectd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.BoundingRect`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + bbox_key_postfix: the output bounding box coordinates will be + written to the value of `{key}_{bbox_key_postfix}`. + """ + + def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox"): + super().__init__(keys=keys) + self.bbox = BoundingRect() + self.bbox_key_postfix = bbox_key_postfix + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + """ + See also: :py:class:`monai.transforms.utils.compute_bounding_rect`. + """ + d = dict(data) + for key in self.keys: + bbox = self.bbox(d[key]) + key_to_add = f"{key}_{self.bbox_key_postfix}" + if key_to_add in d: + raise KeyError(f"Bounding box data with key {key_to_add} already exists.") + d[key_to_add] = bbox + return d + + SpatialPadD = SpatialPadDict = SpatialPadd BorderPadD = BorderPadDict = BorderPadd DivisiblePadD = DivisiblePadDict = DivisiblePadd @@ -591,3 +622,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandWeightedCropD = RandWeightedCropDict = RandWeightedCropd RandCropByPosNegLabelD = RandCropByPosNegLabelDict = RandCropByPosNegLabeld ResizeWithPadOrCropD = ResizeWithPadOrCropDict = ResizeWithPadOrCropd +BoundingRectD = BoundingRectDict = BoundingRectd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 15d31016f2..54d0632d9c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import random import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union @@ -537,6 +538,32 @@ def generate_spatial_bounding_box( return box_start, box_end +def compute_bounding_rect(image: np.array): + """ + Compute ND coordinates of a bounding rectangle from the positive intensities. + The output format of the coordinates is: + + [1st_spatial_dim_start, 1st_spatial_dim_end, + 2nd_spatial_dim_start, 2nd_spatial_dim_end, + ..., + Nth_spatial_dim_start, Nth_spatial_dim_end,] + + The bounding boxes edges are aligned with the input image edges. + This function returns [-1, -1, ...] if there's no positive intensity. + """ + _binary_image = image > 0 + ndim = len(_binary_image.shape) + bbox = [0] * (2 * ndim) + for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): + dt = _binary_image.any(axis=ax) + if not np.any(dt): + return np.asarray([-1] * len(bbox)) + min_d = np.argmax(dt) + max_d = max(_binary_image.shape[di] - np.argmax(dt[::-1]), min_d + 1) + bbox[di * 2], bbox[di * 2 + 1] = min_d, max_d + return np.asarray(bbox) + + def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor: """ Gets the largest connected component mask of an image. diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py new file mode 100644 index 0000000000..04fea8a22c --- /dev/null +++ b/tests/test_bounding_rect.py @@ -0,0 +1,43 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +import monai +from monai.transforms import BoundingRect + +TEST_CASE_1 = [(2, 3), [[-1, -1], [1, 2]]] + +TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] + +TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] + + +class TestBoundingRect(unittest.TestCase): + def setUp(self): + monai.utils.set_determinism(1) + + def tearDown(self): + monai.utils.set_determinism(None) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_shape, expected): + test_data = np.random.randint(0, 8, size=input_shape) + test_data = test_data == 7 + result = BoundingRect()(test_data) + np.testing.assert_allclose(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py new file mode 100644 index 0000000000..c33a3c371d --- /dev/null +++ b/tests/test_bounding_rectd.py @@ -0,0 +1,49 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +import monai +from monai.transforms import BoundingRectD + +TEST_CASE_1 = [(2, 3), [[-1, -1], [1, 2]]] + +TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] + +TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] + + +class TestBoundingRectD(unittest.TestCase): + def setUp(self): + monai.utils.set_determinism(1) + + def tearDown(self): + monai.utils.set_determinism(None) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_shape, expected): + test_data = np.random.randint(0, 8, size=input_shape) + test_data = test_data == 7 + result = BoundingRectD("image")({"image": test_data}) + np.testing.assert_allclose(result["image_bbox"], expected) + + result = BoundingRectD("image", "cc")({"image": test_data}) + np.testing.assert_allclose(result["image_cc"], expected) + + with self.assertRaises(KeyError): + BoundingRectD("image", "cc")({"image": test_data, "image_cc": None}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index aadf379dcb..aa0fd57f76 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -243,7 +243,7 @@ def test_training(self): repeated.append(results) np.testing.assert_allclose(repeated[0], repeated[1]) - @TimedCall(seconds=1000, skip_timing=not torch.cuda.is_available(), daemon=False, force_quit=False) + @TimedCall(seconds=1000, skip_timing=not torch.cuda.is_available(), daemon=False) def test_timing(self): self.train_and_infer()