diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 12d95c583..06a3173d4 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -388,7 +388,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4) batch_idx, *_ = torch.where(valid_mask) - nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou) + nms_idx = batched_nms(boxes=valid_box, scores=valid_con, idxs=batch_idx, iou_threshold=nms_cfg.min_iou) predicts_nms = [] for idx in range(cls_dist.size(0)): instance_idx = nms_idx[idx == batch_idx[nms_idx]]