Unify bounding-box computation on generate_spatial_bounding_box.

BoundingRect / BoundingRectd gain a `select_fn` argument and delegate to
generate_spatial_bounding_box, which now rejects negative margins and returns
-1 coordinates when nothing is selected; compute_bounding_rect is removed.

diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 2cd8ee861e..4c69a61b15 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -21,7 +21,6 @@
 from monai.data.utils import get_random_patch, get_valid_patch_size
 from monai.transforms.compose import Randomizable, Transform
 from monai.transforms.utils import (
-    compute_bounding_rect,
     generate_pos_neg_label_crop_centers,
     generate_spatial_bounding_box,
     map_binary_to_indices,
@@ -638,11 +637,38 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
 class BoundingRect(Transform):
     """
     Compute coordinates of axis-aligned bounding rectangles from input image `img`.
+    The output format of the coordinates is (shape is [channel, 2 * spatial dims]):
+
+        [[1st_spatial_dim_start, 1st_spatial_dim_end,
+         2nd_spatial_dim_start, 2nd_spatial_dim_end,
+         ...,
+         Nth_spatial_dim_start, Nth_spatial_dim_end],
+
+         ...
+
+         [1st_spatial_dim_start, 1st_spatial_dim_end,
+         2nd_spatial_dim_start, 2nd_spatial_dim_end,
+         ...,
+         Nth_spatial_dim_start, Nth_spatial_dim_end]]
+
+    The bounding box edges are aligned with the input image edges.
+    This transform returns [-1, -1, ...] for a channel if `select_fn` selects no foreground in it.
+
+    Args:
+        select_fn: function to select expected foreground, default is to select values > 0.
     """
 
+    def __init__(self, select_fn: Callable = lambda x: x > 0) -> None:
+        self.select_fn = select_fn
+
     def __call__(self, img: np.ndarray) -> np.ndarray:
         """
-        See also: :py:class:`monai.transforms.utils.compute_bounding_rect`.
+        See also: :py:func:`monai.transforms.utils.generate_spatial_bounding_box`.
         """
-        bbox = [compute_bounding_rect(channel) for channel in img]
+        bbox = list()
+
+        for channel in range(img.shape[0]):
+            start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel)
+            bbox.append([i for k in zip(start_, end_) for i in k])
+
         return np.stack(bbox, axis=0)
diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py
index 557e80474b..8e927eb605 100644
--- a/monai/transforms/croppad/dictionary.py
+++ b/monai/transforms/croppad/dictionary.py
@@ -590,16 +590,17 @@ class BoundingRectd(MapTransform):
             See also: monai.transforms.MapTransform
         bbox_key_postfix: the output bounding box coordinates will be
             written to the value of `{key}_{bbox_key_postfix}`.
+        select_fn: function to select expected foreground, default is to select values > 0.
     """
 
-    def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox"):
+    def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox", select_fn: Callable = lambda x: x > 0):
         super().__init__(keys=keys)
-        self.bbox = BoundingRect()
+        self.bbox = BoundingRect(select_fn=select_fn)
         self.bbox_key_postfix = bbox_key_postfix
 
     def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
         """
-        See also: :py:class:`monai.transforms.utils.compute_bounding_rect`.
+        See also: :py:func:`monai.transforms.utils.generate_spatial_bounding_box`.
         """
         d = dict(data)
         for key in self.keys:
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 54d0632d9c..44205e4e09 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -516,6 +516,13 @@ def generate_spatial_bounding_box(
     generate the spatial bounding box of foreground in the image with start-end positions.
     Users can define arbitrary function to select expected foreground from the whole image or specified channels.
     And it can also add margin to every dim of the bounding box.
+    The output format of the coordinates is:
+
+        [1st_spatial_dim_start, 2nd_spatial_dim_start, ..., Nth_spatial_dim_start],
+        [1st_spatial_dim_end, 2nd_spatial_dim_end, ..., Nth_spatial_dim_end]
+
+    The bounding box edges are aligned with the input image edges.
+    This function returns [-1, -1, ...], [-1, -1, ...] if `select_fn` selects no foreground.
 
     Args:
         img: source image to generate bounding box from.
@@ -524,44 +531,27 @@ def generate_spatial_bounding_box(
             of image. if None, select foreground on the whole image.
         margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
     """
-    data = img[[*(ensure_tuple(channel_indices))]] if channel_indices is not None else img
+    data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img
     data = np.any(select_fn(data), axis=0)
-    nonzero_idx = np.nonzero(data)
-    margin = ensure_tuple_rep(margin, data.ndim)
-
-    box_start = list()
-    box_end = list()
-    for i in range(data.ndim):
-        assert len(nonzero_idx[i]) > 0, f"did not find nonzero index at spatial dim {i}"
-        box_start.append(max(0, np.min(nonzero_idx[i]) - margin[i]))
-        box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin[i] + 1))
-    return box_start, box_end
-
-
-def compute_bounding_rect(image: np.array):
-    """
-    Compute ND coordinates of a bounding rectangle from the positive intensities.
-    The output format of the coordinates is:
+    ndim = len(data.shape)
+    margin = ensure_tuple_rep(margin, ndim)
+    for m in margin:
+        if m < 0:
+            raise ValueError("margin value should not be negative number.")
 
-        [1st_spatial_dim_start, 1st_spatial_dim_end,
-         2nd_spatial_dim_start, 2nd_spatial_dim_end,
-         ...,
-         Nth_spatial_dim_start, Nth_spatial_dim_end,]
+    box_start = [0] * ndim
+    box_end = [0] * ndim
 
-    The bounding boxes edges are aligned with the input image edges.
-    This function returns [-1, -1, ...] if there's no positive intensity.
-    """
-    _binary_image = image > 0
-    ndim = len(_binary_image.shape)
-    bbox = [0] * (2 * ndim)
     for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)):
-        dt = _binary_image.any(axis=ax)
+        dt = data.any(axis=ax)
         if not np.any(dt):
-            return np.asarray([-1] * len(bbox))
-        min_d = np.argmax(dt)
-        max_d = max(_binary_image.shape[di] - np.argmax(dt[::-1]), min_d + 1)
-        bbox[di * 2], bbox[di * 2 + 1] = min_d, max_d
-    return np.asarray(bbox)
+            return [-1] * ndim, [-1] * ndim
+
+        min_d = max(np.argmax(dt) - margin[di], 0)
+        max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1)
+        box_start[di], box_end[di] = min_d, max_d
+
+    return box_start, box_end
 
 
 def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor:
diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py
index 04fea8a22c..69476479a3 100644
--- a/tests/test_bounding_rect.py
+++ b/tests/test_bounding_rect.py
@@ -38,6 +38,12 @@ def test_shape(self, input_shape, expected):
         result = BoundingRect()(test_data)
         np.testing.assert_allclose(result, expected)
 
+    def test_select_fn(self):
+        test_data = np.random.randint(0, 8, size=(2, 3))
+        test_data = test_data == 7
+        bbox = BoundingRect(select_fn=lambda x: x < 1)(test_data)
+        np.testing.assert_allclose(bbox, [[0, 3], [0, 3]])
+
 
 if __name__ == "__main__":
     unittest.main()