diff --git a/tests/test_utils/test_bounding_box_utils.py b/tests/test_utils/test_bounding_box_utils.py index 58a9a9172..efa3c8ed4 100644 --- a/tests/test_utils/test_bounding_box_utils.py +++ b/tests/test_utils/test_bounding_box_utils.py @@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config): def test_bbox_nms(): - cls_dist = tensor( - [[[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], [[0.4, 0.4, 0.2], [0.5, 0.4, 0.1]]] # Example class distribution + cls_dist = torch.tensor( + [ + [ + [0.7, 0.1, 0.2], # High confidence, class 0 + [0.3, 0.6, 0.1], # High confidence, class 1 + [-3.0, -2.0, -1.0], # low confidence, class 2 + [0.6, 0.2, 0.2], # Medium confidence, class 0 + ], + [ + [0.55, 0.25, 0.2], # Medium confidence, class 0 + [-4.0, -0.5, -2.0], # low confidence, class 1 + [0.15, 0.2, 0.65], # Medium confidence, class 2 + [0.8, 0.1, 0.1], # High confidence, class 0 + ], + ], + dtype=float32, ) - bbox = tensor( - [[[50, 50, 100, 100], [60, 60, 110, 110]], [[40, 40, 90, 90], [70, 70, 120, 120]]], # Example bounding boxes + + bbox = torch.tensor( + [ + [ + [0, 0, 160, 120], # Overlaps with box 4 + [160, 120, 320, 240], + [0, 120, 160, 240], + [16, 12, 176, 132], + ], + [ + [0, 0, 160, 120], # Overlaps with box 4 + [160, 120, 320, 240], + [0, 120, 160, 240], + [16, 12, 176, 132], + ], + ], dtype=float32, ) + nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5) - expected_output = [ - tensor( + # Batch 1: + # - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out + # - box 2 is kept with class 1 + # - box 3 is rejected by the confidence filter + # Batch 2: + # - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out + # - box 2 is rejected by the confidence filter + # - box 3 is kept with class 2 + expected_output = torch.tensor( + [ [ - [1.0000, 50.0000, 50.0000, 100.0000, 100.0000, 0.6682], - [0.0000, 60.0000, 60.0000, 110.0000, 110.0000, 0.6457], - ] - ) - ] + [0.0, 0.0, 0.0, 160.0, 120.0, 0.6682], + [1.0, 160.0, 120.0, 320.0, 240.0, 0.6457], + ], + [ + [0.0, 16.0, 12.0, 176.0, 132.0, 0.6900], + [2.0, 0.0, 120.0, 160.0, 240.0, 0.6570], + ], + ] + ) output = bbox_nms(cls_dist, bbox, nms_cfg) diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 47c1636a6..3485fa85d 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -387,7 +387,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4) batch_idx, *_ = torch.where(valid_mask) - nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou) + nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou) predicts_nms = [] for idx in range(cls_dist.size(0)): instance_idx = nms_idx[idx == batch_idx[nms_idx]]