diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h
index 60ee4cb88e43..5408582c8356 100644
--- a/include/tvm/relay/attrs/vision.h
+++ b/include/tvm/relay/attrs/vision.h
@@ -40,6 +40,22 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode {
   }
 };
 
+/*! \brief Attributes used in non_maximum_suppression operators */
+struct NMSAttrs : public tvm::AttrsNode<NMSAttrs> {
+  double overlap_threshold;
+  bool force_suppress;
+  int topk;
+
+  TVM_DECLARE_ATTRS(NMSAttrs, "relay.attrs.NMSAttrs") {
+    TVM_ATTR_FIELD(overlap_threshold).set_default(0.5)
+      .describe("Non-maximum suppression threshold.");
+    TVM_ATTR_FIELD(force_suppress).set_default(false)
+      .describe("Suppress all detections regardless of class_id.");
+    TVM_ATTR_FIELD(topk).set_default(-1)
+      .describe("Keep maximum top k detections before nms, -1 for no limit.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_VISION_H_
diff --git a/python/tvm/relay/op/vision/__init__.py b/python/tvm/relay/op/vision/__init__.py
index b3010d2d5310..9ecd8a84770a 100644
--- a/python/tvm/relay/op/vision/__init__.py
+++ b/python/tvm/relay/op/vision/__init__.py
@@ -3,3 +3,4 @@
 from __future__ import absolute_import as _abs
 
 from .multibox import *
+from .nms import *
diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py
new file mode 100644
index 000000000000..8035e3030b17
--- /dev/null
+++ b/python/tvm/relay/op/vision/nms.py
@@ -0,0 +1,36 @@
+"""Non-maximum suppression operations."""
+from __future__ import absolute_import as _abs
+from . import _make
+
+def nms(data,
+        valid_count,
+        overlap_threshold=0.5,
+        force_suppress=False,
+        topk=-1):
+    """Non-maximum suppression operator for object detection.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        3-D tensor with shape [batch_size, num_anchors, 6].
+        The last dimension should be in format of
+        [class_id, score, box_left, box_top, box_right, box_bottom].
+
+    valid_count : relay.Expr
+        1-D tensor for valid number of boxes.
+
+    overlap_threshold : float, optional
+        Non-maximum suppression threshold.
+
+    force_suppress : bool, optional
+        Suppress all detections regardless of class_id.
+
+    topk : int, optional
+        Keep maximum top k detections before nms, -1 for no limit.
+
+    Returns
+    -------
+    out : relay.Expr
+        3-D tensor with shape [batch_size, num_anchors, 6].
+    """
+    return _make.nms(data, valid_count, overlap_threshold, force_suppress, topk)
diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc
index ce069a78186b..e347e544e4f9 100644
--- a/src/relay/op/vision/multibox_op.cc
+++ b/src/relay/op/vision/multibox_op.cc
@@ -5,7 +5,6 @@
  */
 #include
 #include
-#include
 
 namespace tvm {
 namespace relay {
@@ -66,7 +65,7 @@ RELAY_REGISTER_OP("vision.multibox_prior")
 .set_attrs_type_key("relay.attrs.MultiBoxPriorAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(4)
+.set_support_level(5)
 .add_type_rel("MultiBoxPrior", MultiboxPriorRel);
 
 }  // namespace relay
diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc
new file mode 100644
index 000000000000..3e3f73bc6cb4
--- /dev/null
+++ b/src/relay/op/vision/nms.cc
@@ -0,0 +1,75 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file nms.cc
+ * \brief Non-maximum suppression operators
+ */
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/vision.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(NMSAttrs);
+
+/*!
+ * \brief Type relation for vision.nms.
+ *
+ * types holds [data, valid_count, output]. The output is assigned the
+ * same shape and dtype as data.
+ *
+ * \return false while an input is not yet a tensor type, true once the
+ *         output type has been assigned.
+ */
+bool NMSRel(const Array<Type>& types,
+            int num_inputs,
+            const Attrs& attrs,
+            const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* valid_count = types[1].as<TensorTypeNode>();
+  // as<>() yields nullptr when a type is not (yet) a TensorTypeNode;
+  // bail out instead of dereferencing a null pointer below.
+  if (data == nullptr || valid_count == nullptr) return false;
+  const auto& dshape = data->shape;
+  const auto& vshape = valid_count->shape;
+  CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
+  CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";
+
+  // assign output type
+  reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
+  return true;
+}
+
+
+/*! \brief Build a call to vision.nms carrying the given NMSAttrs values. */
+Expr MakeNMS(Expr data,
+             Expr valid_count,
+             double overlap_threshold,
+             bool force_suppress,
+             int topk) {
+  auto attrs = make_node<NMSAttrs>();
+  attrs->overlap_threshold = overlap_threshold;
+  attrs->force_suppress = force_suppress;
+  attrs->topk = topk;
+  static const Op& op = Op::Get("vision.nms");
+  return CallNode::make(op, {data, valid_count}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.vision._make.nms")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 5>(MakeNMS, args, rv);
+});
+
+
+RELAY_REGISTER_OP("vision.nms")
+.describe(R"doc(Non-maximum suppression.
+)doc" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "Input data.")
+.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
+.set_support_level(5)
+.add_type_rel("NMS", NMSRel);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 4e554cd0cf81..0bd7a4816a1b 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -18,7 +18,6 @@ def test_resize_infer_type():
     assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
 
 
-
 def test_multibox_prior():
     sizes = (0.3, 1.5, 0.7)
     ratios = (1.3, 2.4)
@@ -44,6 +43,36 @@ def test_multibox_prior():
         (1, h * w, 4), "float32")
 
 
+def test_nms():
+    num_anchors = 60
+
+    overlap_threshold = 0.5
+    force_suppress = True
+    nms_topk = 10
+
+    n = tvm.var("n")
+    x0 = relay.var("x0", relay.ty.TensorType((n, num_anchors, 6), "float32"))
+    x1 = relay.var("x1", relay.ty.TensorType((n,), "int"))
+
+    z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, nms_topk)
+
+    assert "overlap_threshold" in z.astext()
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.ty.TensorType(
+        (n, num_anchors, 6), "float32")
+
+    n = tvm.var("n")
+    x0 = relay.var("x0", relay.ty.TensorType((n, num_anchors, 6), "float32"))
+    x1 = relay.var("x1", relay.ty.TensorType((n,), "int"))
+
+    z = relay.vision.nms(x0, x1)
+
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.ty.TensorType(
+        (n, num_anchors, 6), "float32")
+
+
 if __name__ == "__main__":
     test_resize_infer_type()
     test_multibox_prior()
+    test_nms()