diff --git a/monai/apps/detection/networks/retinanet_detector.py b/monai/apps/detection/networks/retinanet_detector.py
index a6a4cf4e56..69272216c4 100644
--- a/monai/apps/detection/networks/retinanet_detector.py
+++ b/monai/apps/detection/networks/retinanet_detector.py
@@ -485,7 +485,9 @@ def forward(
         """
         # 1. Check if input arguments are valid
         if self.training:
-            check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
+            targets = check_training_targets(
+                input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key
+            )
             self._check_detector_training_components()
 
         # 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
@@ -877,7 +879,7 @@ def get_cls_train_sample_per_image(
 
         foreground_idxs_per_image = matched_idxs_per_image >= 0
 
-        num_foreground = foreground_idxs_per_image.sum()
+        num_foreground = int(foreground_idxs_per_image.sum())
         num_gt_box = targets_per_image[self.target_box_key].shape[0]
 
         if self.debug:
diff --git a/monai/apps/detection/transforms/array.py b/monai/apps/detection/transforms/array.py
index 491af077f0..d8ffce4584 100644
--- a/monai/apps/detection/transforms/array.py
+++ b/monai/apps/detection/transforms/array.py
@@ -28,6 +28,7 @@
     convert_box_to_standard_mode,
     get_spatial_dims,
     spatial_crop_boxes,
+    standardize_empty_box,
 )
 from monai.transforms import Rotate90, SpatialCrop
 from monai.transforms.transform import Transform
@@ -46,6 +47,7 @@
 )
 
 __all__ = [
+    "StandardizeEmptyBox",
     "ConvertBoxToStandardMode",
     "ConvertBoxMode",
     "AffineBox",
@@ -60,6 +62,27 @@
 ]
 
 
+class StandardizeEmptyBox(Transform):
+    """
+    When boxes are empty, this transform standardizes them to shape (0,4) or (0,6).
+
+    Args:
+        spatial_dims: number of spatial dimensions of the bounding boxes.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self, spatial_dims: int) -> None:
+        self.spatial_dims = spatial_dims
+
+    def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:
+        """
+        Args:
+            boxes: source bounding boxes, Nx4 or Nx6 or 0xM torch tensor or ndarray.
+        """
+        return standardize_empty_box(boxes, spatial_dims=self.spatial_dims)
+
+
 class ConvertBoxMode(Transform):
     """
     This transform converts the boxes in src_mode to the dst_mode.
diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py
index f77c5f4c48..a692a42369 100644
--- a/monai/apps/detection/transforms/dictionary.py
+++ b/monai/apps/detection/transforms/dictionary.py
@@ -34,6 +34,7 @@
     MaskToBox,
     RotateBox90,
     SpatialCropBox,
+    StandardizeEmptyBox,
     ZoomBox,
 )
 from monai.apps.detection.transforms.box_ops import convert_box_to_mask
@@ -51,6 +52,9 @@
 from monai.utils.type_conversion import convert_data_type, convert_to_tensor
 
 __all__ = [
+    "StandardizeEmptyBoxd",
+    "StandardizeEmptyBoxD",
+    "StandardizeEmptyBoxDict",
     "ConvertBoxModed",
     "ConvertBoxModeD",
     "ConvertBoxModeDict",
@@ -95,6 +99,50 @@
 DEFAULT_POST_FIX = PostFix.meta()
 
 
+class StandardizeEmptyBoxd(MapTransform, InvertibleTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`.
+
+    When boxes are empty, this transform standardizes them to shape (0,4) or (0,6).
+
+    Example:
+        .. code-block:: python
+
+            data = {"boxes": torch.ones(0,), "image": torch.ones(1, 128, 128, 128)}
+            box_converter = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")
+            box_converter(data)
+    """
+
+    def __init__(self, box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False) -> None:
+        """
+        Args:
+            box_keys: Keys to pick data for transformation.
+            box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.
+            allow_missing_keys: don't raise exception if key is missing.
+
+        See also :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`.
+        """
+        super().__init__(box_keys, allow_missing_keys)
+        box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
+        if len(box_ref_image_keys_tuple) > 1:
+            raise ValueError(
+                "Please provide a single key for box_ref_image_keys. "
+                "All boxes of box_keys are attached to box_ref_image_keys."
+            )
+        self.box_ref_image_keys = box_ref_image_keys
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        spatial_dims = len(d[self.box_ref_image_keys].shape) - 1
+        self.converter = StandardizeEmptyBox(spatial_dims=spatial_dims)
+        for key in self.key_iterator(d):
+            d[key] = self.converter(d[key])
+        return d
+
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+        return dict(data)
+
+
 class ConvertBoxModed(MapTransform, InvertibleTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ConvertBoxMode`.
@@ -1353,3 +1401,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
 RandCropBoxByPosNegLabelD = RandCropBoxByPosNegLabelDict = RandCropBoxByPosNegLabeld
 RotateBox90D = RotateBox90Dict = RotateBox90d
 RandRotateBox90D = RandRotateBox90Dict = RandRotateBox90d
+StandardizeEmptyBoxD = StandardizeEmptyBoxDict = StandardizeEmptyBoxd
diff --git a/monai/apps/detection/utils/detector_utils.py b/monai/apps/detection/utils/detector_utils.py
index 493ce5b216..a687476996 100644
--- a/monai/apps/detection/utils/detector_utils.py
+++ b/monai/apps/detection/utils/detector_utils.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+import warnings
 from collections.abc import Sequence
 from typing import Any
 
@@ -18,6 +19,7 @@
 import torch.nn.functional as F
 from torch import Tensor
 
+from monai.data.box_utils import standardize_empty_box
 from monai.transforms.croppad.array import SpatialPad
 from monai.transforms.utils import compute_divisible_spatial_size, convert_pad_mode
 from monai.utils import PytorchPadMode, ensure_tuple_rep
@@ -56,7 +58,7 @@ def check_training_targets(
     spatial_dims: int,
     target_label_key: str,
     target_box_key: str,
-) -> None:
+) -> list[dict[str, Tensor]]:
     """
     Validate the input images/targets during training (raise a `ValueError` if invalid).
 
@@ -75,7 +77,8 @@
     if len(input_images) != len(targets):
         raise ValueError(f"len(input_images) should equal to len(targets), got {len(input_images)}, {len(targets)}.")
 
-    for target in targets:
+    for i in range(len(targets)):
+        target = targets[i]
         if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):
             raise ValueError(
                 f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}."
@@ -85,10 +88,24 @@ def check_training_targets(
         if not isinstance(boxes, torch.Tensor):
             raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
         if len(boxes.shape) != 2 or boxes.shape[-1] != 2 * spatial_dims:
-            raise ValueError(
-                f"Expected target boxes to be a tensor " f"of shape [N, {2* spatial_dims}], got {boxes.shape}."
-            )
-    return
+            if boxes.numel() == 0:
+                warnings.warn(
+                    f"Given target boxes have shape {boxes.shape}. "
+                    f"The detector reshaped them with boxes = torch.reshape(boxes, [0, {2* spatial_dims}])."
+                )
+            else:
+                raise ValueError(
+                    f"Expected target boxes to be a tensor of shape [N, {2* spatial_dims}], got {boxes.shape}."
+                )
+        if not torch.is_floating_point(boxes):
+            raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
+        targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims)  # type: ignore
+
+        labels = target[target_label_key]
+        if torch.is_floating_point(labels):
+            warnings.warn(f"Given target labels are {labels.dtype}. The detector converted them to torch.long.")
+            targets[i][target_label_key] = labels.long()
+    return targets
 
 
 def pad_images(
diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py
index b040119626..1010e10b2f 100644
--- a/monai/data/box_utils.py
+++ b/monai/data/box_utils.py
@@ -395,19 +395,41 @@
 
     # Check the validity of each input and add its corresponding spatial_dims to spatial_dims_set
     if boxes is not None:
-        if int(boxes.shape[1]) not in [4, 6]:
+        if len(boxes.shape) != 2:
+            if boxes.shape[0] == 0:
+                raise ValueError(
+                    f"Currently we support only boxes with shape [N,4] or [N,6], "
+                    f"got boxes with shape {boxes.shape}. "
+                    f"Please reshape it with boxes = torch.reshape(boxes, [0, 4]) or torch.reshape(boxes, [0, 6])."
+                )
+            else:
+                raise ValueError(
+                    f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
+                )
+        if int(boxes.shape[1] / 2) not in SUPPORTED_SPATIAL_DIMS:
             raise ValueError(
                 f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
             )
         spatial_dims_set.add(int(boxes.shape[1] / 2))
     if points is not None:
+        if len(points.shape) != 2:
+            if points.shape[0] == 0:
+                raise ValueError(
+                    f"Currently we support only points with shape [N,2] or [N,3], "
+                    f"got points with shape {points.shape}. "
+                    f"Please reshape it with points = torch.reshape(points, [0, 2]) or torch.reshape(points, [0, 3])."
+                )
+            else:
+                raise ValueError(
+                    f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
+                )
         if int(points.shape[1]) not in SUPPORTED_SPATIAL_DIMS:
             raise ValueError(
-                f"Currently we support only points with shape [N,2] or [N,3], got boxes with shape {points.shape}."
+                f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
             )
         spatial_dims_set.add(int(points.shape[1]))
     if corners is not None:
-        if len(corners) not in [4, 6]:
+        if len(corners) // 2 not in SUPPORTED_SPATIAL_DIMS:
             raise ValueError(
                 f"Currently we support only boxes with shape [N,4] or [N,6], got box corner tuple with length {len(corners)}."
             )
@@ -494,6 +516,33 @@ def get_boxmode(mode: str | BoxMode | type[BoxMode] | None = None, *args, **kwargs):
     return StandardMode(*args, **kwargs)
 
 
+def standardize_empty_box(boxes: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
+    """
+    When boxes are empty, this function standardizes them to shape (0,4) or (0,6).
+
+    Args:
+        boxes: bounding boxes, Nx4 or Nx6 or empty torch tensor or ndarray.
+        spatial_dims: number of spatial dimensions of the bounding boxes.
+
+    Returns:
+        bounding boxes with shape (N,4) or (N,6), N can be 0.
+
+    Example:
+        .. code-block:: python
+
+            boxes = torch.ones(0,)
+            standardize_empty_box(boxes, 3)
+    """
+    # convert numpy to tensor if needed
+    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)
+    # handle empty box
+    if boxes_t.shape[0] == 0:
+        boxes_t = torch.reshape(boxes_t, [0, spatial_dims * 2])
+    # convert tensor back to numpy if needed
+    boxes_dst, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)
+    return boxes_dst
+
+
 def convert_box_mode(
     boxes: NdarrayOrTensor,
     src_mode: str | BoxMode | type[BoxMode] | None = None,
@@ -522,6 +571,10 @@
         convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
         convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
     """
+    # handle empty box
+    if boxes.shape[0] == 0:
+        return boxes
+
     src_boxmode = get_boxmode(src_mode)
     dst_boxmode = get_boxmode(dst_mode)
 
diff --git a/tests/test_retinanet_detector.py b/tests/test_retinanet_detector.py
index a5a4001f5c..f67430b0cd 100644
--- a/tests/test_retinanet_detector.py
+++ b/tests/test_retinanet_detector.py
@@ -134,20 +134,21 @@ def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape):
         detector.set_atss_matcher()
         detector.set_hard_negative_sampler(10, 0.5)
 
-        gt_box_start = torch.randint(2, (3, input_param["spatial_dims"])).to(torch.float16)
-        gt_box_end = gt_box_start + torch.randint(1, 10, (3, input_param["spatial_dims"]))
-        one_target = {
-            "boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
-            "labels": torch.randint(input_param["num_classes"], (3,)),
-        }
-        with train_mode(detector):
-            input_data = torch.randn(input_shape)
-            targets = [one_target] * len(input_data)
-            result = detector.forward(input_data, targets)
-
-            input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
-            targets = [one_target] * len(input_data)
-            result = detector.forward(input_data, targets)
+        for num_gt_box in [0, 3]:  # test for both empty and non-empty boxes
+            gt_box_start = torch.randint(2, (num_gt_box, input_param["spatial_dims"])).to(torch.float16)
+            gt_box_end = gt_box_start + torch.randint(1, 10, (num_gt_box, input_param["spatial_dims"]))
+            one_target = {
+                "boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
+                "labels": torch.randint(input_param["num_classes"], (num_gt_box,)),
+            }
+            with train_mode(detector):
+                input_data = torch.randn(input_shape)
+                targets = [one_target] * len(input_data)
+                result = detector.forward(input_data, targets)
+
+                input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
+                targets = [one_target] * len(input_data)
+                result = detector.forward(input_data, targets)
 
     @parameterized.expand(TEST_CASES)
     def test_naive_retina_detector_shape(self, input_param, input_shape):
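
A quick usage sketch, assuming the patch above is applied to a MONAI tree: it exercises the new ``standardize_empty_box`` helper, the empty-box early return added to ``convert_box_mode``, and the new ``StandardizeEmptyBoxd`` dictionary transform. The mode strings ``"xyzxyz"`` and ``"cccwhd"`` are MONAI's existing 3-D corner-corner and center-size box modes.

.. code-block:: python

    import torch

    from monai.apps.detection.transforms.dictionary import StandardizeEmptyBoxd
    from monai.data.box_utils import convert_box_mode, standardize_empty_box

    # an empty box tensor with a non-standard shape, e.g. from an image containing no objects
    boxes = torch.ones(0)

    # the array-level helper reshapes it to (0, 2 * spatial_dims)
    boxes_3d = standardize_empty_box(boxes, spatial_dims=3)
    print(boxes_3d.shape)  # torch.Size([0, 6])

    # convert_box_mode now returns empty boxes unchanged instead of raising
    converted = convert_box_mode(boxes_3d, src_mode="xyzxyz", dst_mode="cccwhd")
    print(converted.shape)  # torch.Size([0, 6])

    # the dictionary transform infers spatial_dims from the reference image key
    data = {"boxes": torch.ones(0), "image": torch.ones(1, 128, 128, 128)}
    data = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")(data)
    print(data["boxes"].shape)  # torch.Size([0, 6])

Since ``check_training_targets`` now returns the standardized targets and ``RetinaNetDetector.forward`` consumes that return value, images with zero ground-truth boxes can flow through training end-to-end, which is what the updated test loop exercises.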