From f82d311d9cc439f1132f982b5476da04beecf8f4 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 21 Apr 2021 11:36:22 -0600 Subject: [PATCH 1/2] [ONNX] Support NMS Center Box --- python/tvm/relay/frontend/onnx.py | 14 ++++++++++---- tests/python/frontend/onnx/test_forward.py | 1 - 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cc66cd3c6fe8..03f4371167a3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2506,11 +2506,17 @@ def _impl_v10(cls, inputs, attr, params): iou_threshold = inputs[3] score_threshold = inputs[4] + boxes_dtype = infer_type(boxes).checked_type.dtype + if "center_point_box" in attr: - if attr["center_point_box"] != 0: - raise NotImplementedError( - "Only support center_point_box = 0 in ONNX NonMaxSuprresion" - ) + xc, yc, w, h = _op.split(boxes, 4, axis=2) + half_w = w / _expr.const(2.0, boxes_dtype) + half_h = h / _expr.const(2.0, boxes_dtype) + x1 = xc - half_w + x2 = xc + half_w + y1 = yc - half_h + y2 = yc + half_h + boxes = _op.concatenate([y1, x1, y2, x2], axis=2) if iou_threshold is None: iou_threshold = _expr.const(0.0, dtype="float32") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 595a3b1c89b3..2bc390fcab15 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4177,7 +4177,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_maxpool_with_argmax_2d_precomputed_strides/", "test_maxunpool_export_with_output_shape/", "test_mvn/", - "test_nonmaxsuppression_center_point_box_format/", "test_qlinearconv/", "test_qlinearmatmul_2D/", "test_qlinearmatmul_3D/", From fdf8d53b8aa38dcca608fc055f577d10b6c13007 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 21 Apr 2021 11:40:04 -0600 Subject: [PATCH 2/2] fix silly mistake in contional --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 03f4371167a3..21cf217183c6 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2508,7 +2508,7 @@ def _impl_v10(cls, inputs, attr, params): boxes_dtype = infer_type(boxes).checked_type.dtype - if "center_point_box" in attr: + if attr.get("center_point_box", 0) != 0: xc, yc, w, h = _op.split(boxes, 4, axis=2) half_w = w / _expr.const(2.0, boxes_dtype) half_h = h / _expr.const(2.0, boxes_dtype)