Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ Crop and Pad
:members:
:special-members: __call__

`BoundingRect`
""""""""""""""
.. autoclass:: BoundingRect
:members:
:special-members: __call__

Intensity
^^^^^^^^^

Expand Down Expand Up @@ -564,6 +570,12 @@ Crop and Pad (Dict)
:members:
:special-members: __call__

`BoundingRectd`
"""""""""""""""
.. autoclass:: BoundingRectd
:members:
:special-members: __call__

Instensity (Dict)
^^^^^^^^^^^^^^^^^

Expand Down
14 changes: 14 additions & 0 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import (
compute_bounding_rect,
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
map_binary_to_indices,
Expand Down Expand Up @@ -632,3 +633,16 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
"""
return self.padder(self.cropper(img), mode=mode)


class BoundingRect(Transform):
"""
Compute coordinates of axis-aligned bounding rectangles from input image `img`.
"""

def __call__(self, img: np.ndarray) -> np.ndarray:
"""
See also: :py:class:`monai.transforms.utils.compute_bounding_rect`.
"""
bbox = [compute_bounding_rect(channel) for channel in img]
return np.stack(bbox, axis=0)
32 changes: 32 additions & 0 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from monai.transforms.compose import MapTransform, Randomizable
from monai.transforms.croppad.array import (
BorderPad,
BoundingRect,
CenterSpatialCrop,
DivisiblePad,
ResizeWithPadOrCrop,
Expand Down Expand Up @@ -580,6 +581,36 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class BoundingRectd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.BoundingRect`.

Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
bbox_key_postfix: the output bounding box coordinates will be
written to the value of `{key}_{bbox_key_postfix}`.
"""

def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox"):
super().__init__(keys=keys)
self.bbox = BoundingRect()
self.bbox_key_postfix = bbox_key_postfix

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
"""
See also: :py:class:`monai.transforms.utils.compute_bounding_rect`.
"""
d = dict(data)
for key in self.keys:
bbox = self.bbox(d[key])
key_to_add = f"{key}_{self.bbox_key_postfix}"
if key_to_add in d:
raise KeyError(f"Bounding box data with key {key_to_add} already exists.")
d[key_to_add] = bbox
return d


SpatialPadD = SpatialPadDict = SpatialPadd
BorderPadD = BorderPadDict = BorderPadd
DivisiblePadD = DivisiblePadDict = DivisiblePadd
Expand All @@ -591,3 +622,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
RandWeightedCropD = RandWeightedCropDict = RandWeightedCropd
RandCropByPosNegLabelD = RandCropByPosNegLabelDict = RandCropByPosNegLabeld
ResizeWithPadOrCropD = ResizeWithPadOrCropDict = ResizeWithPadOrCropd
BoundingRectD = BoundingRectDict = BoundingRectd
27 changes: 27 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import random
import warnings
from typing import Callable, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -537,6 +538,32 @@ def generate_spatial_bounding_box(
return box_start, box_end


def compute_bounding_rect(image: np.array):
"""
Compute ND coordinates of a bounding rectangle from the positive intensities.
The output format of the coordinates is:

[1st_spatial_dim_start, 1st_spatial_dim_end,
2nd_spatial_dim_start, 2nd_spatial_dim_end,
...,
Nth_spatial_dim_start, Nth_spatial_dim_end,]

The bounding boxes edges are aligned with the input image edges.
This function returns [-1, -1, ...] if there's no positive intensity.
"""
_binary_image = image > 0
ndim = len(_binary_image.shape)
bbox = [0] * (2 * ndim)
for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)):
dt = _binary_image.any(axis=ax)
if not np.any(dt):
return np.asarray([-1] * len(bbox))
min_d = np.argmax(dt)
max_d = max(_binary_image.shape[di] - np.argmax(dt[::-1]), min_d + 1)
bbox[di * 2], bbox[di * 2 + 1] = min_d, max_d
return np.asarray(bbox)


def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor:
"""
Gets the largest connected component mask of an image.
Expand Down
43 changes: 43 additions & 0 deletions tests/test_bounding_rect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

import monai
from monai.transforms import BoundingRect

TEST_CASE_1 = [(2, 3), [[-1, -1], [1, 2]]]

TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]]

TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]


class TestBoundingRect(unittest.TestCase):
def setUp(self):
monai.utils.set_determinism(1)

def tearDown(self):
monai.utils.set_determinism(None)

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, input_shape, expected):
test_data = np.random.randint(0, 8, size=input_shape)
test_data = test_data == 7
result = BoundingRect()(test_data)
np.testing.assert_allclose(result, expected)


if __name__ == "__main__":
unittest.main()
49 changes: 49 additions & 0 deletions tests/test_bounding_rectd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

import monai
from monai.transforms import BoundingRectD

TEST_CASE_1 = [(2, 3), [[-1, -1], [1, 2]]]

TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]]

TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]


class TestBoundingRectD(unittest.TestCase):
def setUp(self):
monai.utils.set_determinism(1)

def tearDown(self):
monai.utils.set_determinism(None)

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, input_shape, expected):
test_data = np.random.randint(0, 8, size=input_shape)
test_data = test_data == 7
result = BoundingRectD("image")({"image": test_data})
np.testing.assert_allclose(result["image_bbox"], expected)

result = BoundingRectD("image", "cc")({"image": test_data})
np.testing.assert_allclose(result["image_cc"], expected)

with self.assertRaises(KeyError):
BoundingRectD("image", "cc")({"image": test_data, "image_cc": None})


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_integration_classification_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_training(self):
repeated.append(results)
np.testing.assert_allclose(repeated[0], repeated[1])

@TimedCall(seconds=1000, skip_timing=not torch.cuda.is_available(), daemon=False, force_quit=False)
@TimedCall(seconds=1000, skip_timing=not torch.cuda.is_available(), daemon=False)
def test_timing(self):
self.train_and_infer()

Expand Down