Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import (
compute_bounding_rect,
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
map_binary_to_indices,
Expand Down Expand Up @@ -638,11 +637,38 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
class BoundingRect(Transform):
    """
    Compute coordinates of axis-aligned bounding rectangles from input image `img`.
    The output format of the coordinates is (shape is [channel, 2 * spatial dims]):

        [[1st_spatial_dim_start, 1st_spatial_dim_end,
          2nd_spatial_dim_start, 2nd_spatial_dim_end,
          ...,
          Nth_spatial_dim_start, Nth_spatial_dim_end],

         ...

         [1st_spatial_dim_start, 1st_spatial_dim_end,
          2nd_spatial_dim_start, 2nd_spatial_dim_end,
          ...,
          Nth_spatial_dim_start, Nth_spatial_dim_end]]

    The bounding box edges are aligned with the input image edges.
    A channel in which `select_fn` selects no foreground yields [-1, -1, ...].

    Args:
        select_fn: function to select expected foreground, default is to select values > 0.
    """

    def __init__(self, select_fn: Callable = lambda x: x > 0) -> None:
        self.select_fn = select_fn

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """
        Compute one bounding rectangle per channel of `img`.

        Args:
            img: channel-first image; the first axis is iterated as the channel axis.

        Returns:
            Array with one row per channel, each row holding interleaved
            (start, end) pairs for every spatial dim.

        See also: :py:func:`monai.transforms.utils.generate_spatial_bounding_box`.
        """
        bbox = []

        for channel in range(img.shape[0]):
            # The helper returns two parallel lists (all starts, all ends) for the selected channel.
            start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel)
            # Interleave them into [start_0, end_0, start_1, end_1, ...] to match the documented layout.
            bbox.append([i for k in zip(start_, end_) for i in k])

        return np.stack(bbox, axis=0)
7 changes: 4 additions & 3 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,17 @@ class BoundingRectd(MapTransform):
See also: monai.transforms.MapTransform
bbox_key_postfix: the output bounding box coordinates will be
written to the value of `{key}_{bbox_key_postfix}`.
select_fn: function to select expected foreground, default is to select values > 0.
"""

def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox", select_fn: Callable = lambda x: x > 0):
    # `select_fn` is forwarded unchanged to the array transform, so callers that
    # omit it keep the default "values > 0" foreground rule.
    super().__init__(keys=keys)
    self.bbox = BoundingRect(select_fn=select_fn)
    self.bbox_key_postfix = bbox_key_postfix

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
"""
See also: :py:class:`monai.transforms.utils.compute_bounding_rect`.
See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`.
"""
d = dict(data)
for key in self.keys:
Expand Down
56 changes: 23 additions & 33 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,13 @@ def generate_spatial_bounding_box(
generate the spatial bounding box of foreground in the image with start-end positions.
Users can define arbitrary function to select expected foreground from the whole image or specified channels.
And it can also add margin to every dim of the bounding box.
The output format of the coordinates is:

[1st_spatial_dim_start, 2nd_spatial_dim_start, ..., Nth_spatial_dim_start],
[1st_spatial_dim_end, 2nd_spatial_dim_end, ..., Nth_spatial_dim_end]

The bounding box edges are aligned with the input image edges.
This function returns [-1, -1, ...], [-1, -1, ...] if `select_fn` selects no foreground.

Args:
img: source image to generate bounding box from.
Expand All @@ -524,44 +531,27 @@ def generate_spatial_bounding_box(
of image. if None, select foreground on the whole image.
margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
"""
data = img[[*(ensure_tuple(channel_indices))]] if channel_indices is not None else img
data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img
data = np.any(select_fn(data), axis=0)
nonzero_idx = np.nonzero(data)
margin = ensure_tuple_rep(margin, data.ndim)

box_start = list()
box_end = list()
for i in range(data.ndim):
assert len(nonzero_idx[i]) > 0, f"did not find nonzero index at spatial dim {i}"
box_start.append(max(0, np.min(nonzero_idx[i]) - margin[i]))
box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin[i] + 1))
return box_start, box_end


def compute_bounding_rect(image: np.array):
"""
Compute ND coordinates of a bounding rectangle from the positive intensities.
The output format of the coordinates is:
ndim = len(data.shape)
margin = ensure_tuple_rep(margin, ndim)
for m in margin:
if m < 0:
raise ValueError("margin value should not be negative number.")

[1st_spatial_dim_start, 1st_spatial_dim_end,
2nd_spatial_dim_start, 2nd_spatial_dim_end,
...,
Nth_spatial_dim_start, Nth_spatial_dim_end,]
box_start = [0] * ndim
box_end = [0] * ndim

The bounding boxes edges are aligned with the input image edges.
This function returns [-1, -1, ...] if there's no positive intensity.
"""
_binary_image = image > 0
ndim = len(_binary_image.shape)
bbox = [0] * (2 * ndim)
for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)):
dt = _binary_image.any(axis=ax)
dt = data.any(axis=ax)
if not np.any(dt):
return np.asarray([-1] * len(bbox))
min_d = np.argmax(dt)
max_d = max(_binary_image.shape[di] - np.argmax(dt[::-1]), min_d + 1)
bbox[di * 2], bbox[di * 2 + 1] = min_d, max_d
return np.asarray(bbox)
return [-1] * ndim, [-1] * ndim

min_d = max(np.argmax(dt) - margin[di], 0)
max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1)
box_start[di], box_end[di] = min_d, max_d

return box_start, box_end


def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_bounding_rect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def test_shape(self, input_shape, expected):
result = BoundingRect()(test_data)
np.testing.assert_allclose(result, expected)

def test_select_fn(self):
    """Check that a custom ``select_fn`` replaces the default ``x > 0`` foreground rule."""
    # Fixed input instead of random draws: the expected boxes below only hold when the
    # first and last position of every channel are selected, which random data cannot guarantee.
    test_data = np.array([[0, 7, 1], [2, 7, 3]]) == 7
    # On the boolean mask, ``x < 1`` selects the False entries, i.e. every value that was not 7.
    bbox = BoundingRect(select_fn=lambda x: x < 1)(test_data)
    np.testing.assert_allclose(bbox, [[0, 3], [0, 3]])


if __name__ == "__main__":
unittest.main()