diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index df059a6238e1..20b80f33a2a3 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -58,19 +58,42 @@ struct MultiBoxTransformLocAttrs } }; -/*! \brief Attributes used in non_maximum_suppression operators */ -struct NMSAttrs : public tvm::AttrsNode{ - double overlap_threshold; +/*! \brief Attributes used in get_valid_counts operator */ +struct GetValidCountsAttrs : public tvm::AttrsNode { + double score_threshold; + + TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { + TVM_ATTR_FIELD(score_threshold).set_default(0.0) + .describe("Lower limit of score for valid bounding boxes."); + } +}; + +/*! \brief Attributes used in non_maximum_suppression operator */ +struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + int max_output_size; + double iou_threshold; bool force_suppress; - int topk; - - TVM_DECLARE_ATTRS(NMSAttrs, "relay.attrs.NMSAttrs") { - TVM_ATTR_FIELD(overlap_threshold).set_default(0.5) - .describe("Non-maximum suppression threshold."); - TVM_ATTR_FIELD(force_suppress).set_default(false) - .describe("Suppress all detections regardless of class_id."); - TVM_ATTR_FIELD(topk).set_default(-1) - .describe("Keep maximum top k detections before nms, -1 for no limit."); + int top_k; + int id_index; + bool return_indices; + bool invalid_to_bottom; + + TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") { + TVM_ATTR_FIELD(max_output_size).set_default(-1) + .describe("Max number of output valid boxes for each instance." + "By default all valid boxes are returned."); + TVM_ATTR_FIELD(iou_threshold).set_default(0.5) + .describe("Non-maximum suppression threshold."); + TVM_ATTR_FIELD(force_suppress).set_default(false) + .describe("Suppress all detections regardless of class_id."); + TVM_ATTR_FIELD(top_k).set_default(-1) + .describe("Keep maximum top k detections before nms, -1 for no limit."); + TVM_ATTR_FIELD(id_index).set_default(0) + .describe("Axis index of id."); + TVM_ATTR_FIELD(return_indices).set_default(true) + .describe("Whether to return box indices in input data."); + TVM_ATTR_FIELD(invalid_to_bottom).set_default(false) + .describe("Whether to move all invalid bounding boxes to the bottom."); } }; diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index 143a9548f18a..578f928c5b9f 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -443,17 +443,30 @@ struct MultiBoxTransformLocParam : public dmlc::Parameter { - float nms_threshold; +struct NonMaximumSuppressionParam : public dmlc::Parameter { + bool return_indices; + float iou_threshold; bool force_suppress; - int nms_topk; - DMLC_DECLARE_PARAMETER(NMSParam) { - DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) + int top_k; + int id_index; + int max_output_size; + bool invalid_to_bottom; + DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) { + DMLC_DECLARE_FIELD(max_output_size).set_default(-1) + .describe("Max number of output valid boxes for each instance." 
+ "By default all valid boxes are returned."); + DMLC_DECLARE_FIELD(iou_threshold).set_default(0.5) .describe("Non-maximum suppression threshold."); DMLC_DECLARE_FIELD(force_suppress).set_default(false) - .describe("Suppress all detections regardless of class_id."); - DMLC_DECLARE_FIELD(nms_topk).set_default(-1) - .describe("Keep maximum top k detections before nms, -1 for no limit."); + .describe("Suppress all detections regardless of class_id."); + DMLC_DECLARE_FIELD(top_k).set_default(-1) + .describe("Keep maximum top k detections before nms, -1 for no limit."); + DMLC_DECLARE_FIELD(id_index).set_default(0) + .describe("Axis index of id."); + DMLC_DECLARE_FIELD(return_indices).set_default(true) + .describe("Whether to return box indices in input data."); + DMLC_DECLARE_FIELD(invalid_to_bottom).set_default(false) + .describe("Whether to move all invalid bounding boxes to the bottom."); } }; diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 179e1126fd4d..47d7ede96e5f 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -245,11 +245,11 @@ def _contrib_multibox_detection(inputs, attrs): if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2) nms_topk = attrs.get('nms_topk') or -1 new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances} - new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress, - 'nms_topk': int(nms_topk)} + new_attrs1 = {'return_indices': False, 'iou_threshold': float(nms_threshold), + 'force_suppress': force_suppress, 'top_k': int(nms_topk)} data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], inputs[2], **new_attrs0) - return _get_nnvm_op('nms')(data, valid_count, **new_attrs1) + return _get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1) def _elemwise_sum(inputs, _): new_attrs = {'num_args':len(inputs)} diff --git a/nnvm/python/nnvm/top/vision.py b/nnvm/python/nnvm/top/vision.py index 1b20baab47c3..ab32838e10ff 100644 --- a/nnvm/python/nnvm/top/vision.py +++ b/nnvm/python/nnvm/top/vision.py @@ -61,20 +61,25 @@ def compute_multibox_transform_loc(attrs, inputs, _): reg.register_pattern("multibox_detection", OpPattern.OPAQUE) # non-maximum suppression -@reg.register_schedule("nms") +@reg.register_schedule("non_max_suppression") def schedule_nms(_, outs, target): - """Schedule definition of nms""" + """Schedule definition of non_max_suppression""" with tvm.target.create(target): return topi.generic.schedule_nms(outs) -@reg.register_compute("nms") +@reg.register_compute("non_max_suppression") def compute_nms(attrs, inputs, _): - """Compute definition of nms""" - nms_threshold = attrs.get_float('nms_threshold') + """Compute definition of non_max_suppression""" + return_indices = attrs.get_bool('return_indices') + max_output_size = attrs.get_int('max_output_size') + iou_threshold = attrs.get_float('iou_threshold') force_suppress = attrs.get_bool('force_suppress') - nms_topk = attrs.get_int('nms_topk') + top_k = attrs.get_int('top_k') + id_index = attrs.get_int('id_index') + invalid_to_bottom = attrs.get_bool('invalid_to_bottom') - return topi.vision.nms(inputs[0], inputs[1], nms_threshold, - force_suppress, nms_topk) + return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size, + iou_threshold, force_suppress, top_k, + id_index, return_indices, invalid_to_bottom) -reg.register_pattern("nms", OpPattern.OPAQUE) +reg.register_pattern("non_max_suppression", OpPattern.OPAQUE) diff 
--git a/nnvm/src/top/vision/nms.cc b/nnvm/src/top/vision/nms.cc index 2680b894255b..e69a7cb2f036 100644 --- a/nnvm/src/top/vision/nms.cc +++ b/nnvm/src/top/vision/nms.cc @@ -19,11 +19,13 @@ using compiler::FTVMCompute; using tvm::Tensor; using tvm::Array; -DMLC_REGISTER_PARAMETER(NMSParam); +DMLC_REGISTER_PARAMETER(NonMaximumSuppressionParam); bool NMSShape(const NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { + const NonMaximumSuppressionParam& param = + nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]"; TShape dshape = in_attrs->at(0); TShape vshape = in_attrs->at(1); @@ -33,7 +35,14 @@ bool NMSShape(const NodeAttrs& attrs, "(batch_size, num_anchors, 6)."; CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch."; out_attrs->clear(); - NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape); + if (param.return_indices) { + TShape oshape = TShape(2); + oshape[0] = dshape[0]; + oshape[1] = dshape[1]; + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape); + } else { + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape); + } return true; } @@ -56,15 +65,15 @@ inline bool NMSInferLayout(const NodeAttrs& attrs, return true; } -NNVM_REGISTER_OP(nms) +NNVM_REGISTER_OP(non_max_suppression) .describe(R"doc("Non-maximum suppression." )doc" NNVM_ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) -.set_attr_parser(ParamParser) +.set_attr_parser(ParamParser) .set_attr("FGetAttrDict", - ParamGetAttrDict) -.add_arguments(NMSParam::__FIELDS__()) + ParamGetAttrDict) +.add_arguments(NonMaximumSuppressionParam::__FIELDS__()) .add_argument("data", "Tensor", "Input data.") .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") .set_attr("FListInputNames", [](const NodeAttrs& attrs) { diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index fc4e62fb7156..6a42047151e5 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -550,7 +550,7 @@ def test_multibox_transform_loc(): anchors = sym.Variable("anchors") transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds, anchor=anchors) - out = sym.nms(data=transform_loc_data, valid_count=valid_count) + out = sym.non_max_suppression(data=transform_loc_data, valid_count=valid_count, return_indices=False) # Manually create test case np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]]) @@ -573,22 +573,22 @@ def test_multibox_transform_loc(): out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype)) tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5) -def test_nms(): +def test_non_max_suppression(): dshape = (1, 5, 6) data = sym.Variable("data") valid_count = sym.Variable("valid_count", dtype="int32") - nms_threshold = 0.7 + iou_threshold = 0.7 force_suppress = True - nms_topk = 2 - out = sym.nms(data=data, valid_count=valid_count, nms_threshold=nms_threshold, - force_suppress=force_suppress, nms_topk=nms_topk) + top_k = 2 + out = sym.non_max_suppression(data=data, valid_count=valid_count, return_indices=False, + iou_threshold=iou_threshold, force_suppress=force_suppress, top_k=top_k) np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], - 
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], + [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) target = "llvm" @@ -726,7 +726,7 @@ def test_argmax(): test_flip() test_multibox_prior() test_multibox_transform_loc() - test_nms() + test_non_max_suppression() test_slice_like() test_where() test_argmax() diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index e046f39f02ca..581ae75a4bbc 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -315,4 +315,3 @@ def test_forward_minimum(): test_forward_slice() test_forward_maximum() test_forward_minimum() - diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 2e0ccd07fdc1..cdfa75e50419 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -324,13 +324,14 @@ def _mx_multibox_detection(inputs, attrs): 0.2, 0.2)) new_attrs1 = {} - new_attrs1["overlap_threshold"] = attrs.get_float("nms_threshold", 0.5) + new_attrs1["return_indices"] = False + new_attrs1["iou_threshold"] = attrs.get_float("nms_threshold", 0.5) new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False) - new_attrs1["topk"] = attrs.get_int("nms_topk", -1) + new_attrs1["top_k"] = attrs.get_int("nms_topk", -1) ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1], inputs[2], **new_attrs0) - return _op.vision.nms(ret[0], ret[1], **new_attrs1) + return _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1) def _mx_batch_dot(inputs, attrs): @@ -380,6 +381,49 @@ def _mx_proposal(inputs, attrs): return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs) +def _mx_box_nms(inputs, attrs): + force_suppress = attrs.get_bool("force_suppress", False) + iou_thresh = attrs.get_float('overlap_thresh', 0.5) + top_k = attrs.get_int('topk', -1) + valid_thresh = attrs.get_float('valid_thresh', 0) + coord_start = attrs.get_int('coord_start', 2) + score_index = attrs.get_int('score_index', 1) + id_index = attrs.get_int('id_index', -1) + in_format = attrs.get_str('in_format', 'corner') + out_format = attrs.get_str('out_format', 'corner') + if coord_start != 2: + raise RuntimeError('coord_start %s is not supported.' % coord_start) + if score_index != 1: + raise RuntimeError('score_index %s is not supported.' % score_index) + if id_index != -1 and int(id_index) != 0: + raise RuntimeError('id_index %s is not supported.' % id_index) + if in_format != 'corner': + raise RuntimeError('in_format %s is not supported.' % in_format) + if out_format != 'corner': + raise RuntimeError('out_format %s is not supported.' % out_format) + + ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) + nms_out = _op.vision.non_max_suppression(ret[1], + ret[0], + iou_threshold=iou_thresh, + force_suppress=force_suppress, + top_k=top_k, + id_index=id_index, + return_indices=False, + invalid_to_bottom=True) + return nms_out + + +def _mx_l2_normalize(inputs, attrs): + new_attrs = {} + mode = attrs.get_str('mode', 'instance') + if mode != 'channel': + raise RuntimeError('mode %s is not supported.' 
% mode) + new_attrs['eps'] = attrs.get_float('eps', 1e-10) + new_attrs['axis'] = [1] + return _op.nn.l2_normalize(inputs[0], **new_attrs) + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -478,6 +522,7 @@ def _mx_proposal(inputs, attrs): "BatchNorm" : _mx_batch_norm, "BatchNorm_v1" : _mx_batch_norm, "LRN" : _mx_lrn, + "L2Normalization" : _mx_l2_normalize, "slice" : _mx_slice, "slice_like" : _mx_slice_like, "slice_axis" : _mx_slice_axis, @@ -498,6 +543,7 @@ def _mx_proposal(inputs, attrs): "_contrib_ROIAlign" : _mx_roi_align, "_contrib_Proposal" : _mx_proposal, "_contrib_MultiProposal" : _mx_proposal, + "_contrib_box_nms" : _mx_box_nms, # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # @@ -640,6 +686,8 @@ def from_mxnet(symbol, params[k] = _nd.array(v.data().asnumpy()) data = mx.sym.Variable("data") sym = symbol(data) + if isinstance(sym, (list, tuple)): + sym = mx.sym.Group(sym) shape, dtype = _update_shape_dtype(shape, dtype, params) sym = _from_mxnet_impl(sym, shape, dtype) elif isinstance(symbol, mx.gluon.Block): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 845ee02b0582..725f57f54bd8 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -456,7 +456,7 @@ def strided_slice(data, begin, end, strides=None): The indices to begin with in the slicing. end: list of int - Indicies indicating end of the slice. + Indices indicating end of the slice. strides: list of int, optional Specifies the stride values, it can be negative in that case, diff --git a/python/tvm/relay/op/vision/__init__.py b/python/tvm/relay/op/vision/__init__.py index 10cf6c2fd3ee..0cee4e4faeec 100644 --- a/python/tvm/relay/op/vision/__init__.py +++ b/python/tvm/relay/op/vision/__init__.py @@ -6,6 +6,6 @@ from .nms import * from .rcnn import * from .yolo import * -from . import _multibox from . import _rcnn from . import _yolo +from . 
import _vision diff --git a/python/tvm/relay/op/vision/_multibox.py b/python/tvm/relay/op/vision/_vision.py similarity index 62% rename from python/tvm/relay/op/vision/_multibox.py rename to python/tvm/relay/op/vision/_vision.py index e9ef43f7e06f..c887076e6af8 100644 --- a/python/tvm/relay/op/vision/_multibox.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -54,24 +54,46 @@ def compute_multibox_transform_loc(attrs, inputs, _, target): reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE) +# Get counts of valid boxes +@reg.register_schedule("vision.get_valid_counts") +def schedule_get_valid_counts(_, outs, target): + """Schedule definition of get_valid_counts""" + with target: + return topi.generic.schedule_get_valid_counts(outs) + + +@reg.register_compute("vision.get_valid_counts") +def compute_get_valid_counts(attrs, inputs, _, target): + """Compute definition of get_valid_counts""" + score_threshold = get_const_float(attrs.score_threshold) + return topi.vision.get_valid_counts(inputs[0], score_threshold) + +reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE) + + # non-maximum suppression -@reg.register_schedule("vision.nms") +@reg.register_schedule("vision.non_max_suppression") def schedule_nms(_, outs, target): """Schedule definition of nms""" with target: return topi.generic.schedule_nms(outs) -@reg.register_compute("vision.nms") +@reg.register_compute("vision.non_max_suppression") def compute_nms(attrs, inputs, _, target): """Compute definition of nms""" - overlap_threshold = get_const_float(attrs.overlap_threshold) + return_indices = bool(get_const_int(attrs.return_indices)) + max_output_size = get_const_int(attrs.max_output_size) + iou_threshold = get_const_float(attrs.iou_threshold) force_suppress = bool(get_const_int(attrs.force_suppress)) - topk = get_const_int(attrs.topk) + top_k = get_const_int(attrs.top_k) + id_index = get_const_int(attrs.id_index) + invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom)) return [ - topi.vision.nms(inputs[0], inputs[1], overlap_threshold, - force_suppress, topk) + topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size, + iou_threshold, force_suppress, top_k, + id_index, return_indices, invalid_to_bottom) ] -reg.register_pattern("vision.nms", OpPattern.OPAQUE) +reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 8035e3030b17..0124ee29ab9e 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -1,12 +1,41 @@ """Non-maximum suppression operations.""" from __future__ import absolute_import as _abs from . import _make +from ...expr import TupleWrapper -def nms(data, - valid_count, - overlap_threshold=0.5, - force_suppress=False, - topk=-1): +def get_valid_counts(data, + score_threshold): + """Get valid count of bounding boxes given a score threshold. + Also moves valid boxes to the top of input data. + + Parameters + ---------- + data : relay.Expr + Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + + score_threshold : optional, float + Lower limit of score for valid bounding boxes. + + Returns + ------- + valid_count : relay.Expr + 1-D tensor for valid number of boxes. + + out_tensor : relay.Expr + Rearranged data tensor. 
+ """ + return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2) + + +def non_max_suppression(data, + valid_count, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + id_index=0, + return_indices=True, + invalid_to_bottom=False): """Non-maximum suppression operator for object detection. Parameters @@ -19,18 +48,33 @@ def nms(data, valid_count : relay.Expr 1-D tensor for valid number of boxes. - overlap_threshold : float, optional + max_output_size : int, optional + Max number of output valid boxes for each instance. + By default all valid boxes are returned. + + iou_threshold : float, optional Non-maximum suppression threshold. force_suppress : bool, optional Suppress all detections regardless of class_id. - topk : int, optional + top_k : int, optional Keep maximum top k detections before nms, -1 for no limit. + id_index : int, optional + index of the class categories, -1 to disable. + + return_indices : bool, optional + Whether to return box indices in input data. + + invalid_to_bottom : bool, optional + Whether to move all valid bounding boxes to the top. + Returns ------- out : relay.Expr 3-D tensor with shape [batch_size, num_anchors, 6]. """ - return _make.nms(data, valid_count, overlap_threshold, force_suppress, topk) + return _make.non_max_suppression(data, valid_count, max_output_size, + iou_threshold, force_suppress, top_k, + id_index, return_indices, invalid_to_bottom) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index de3ac03977f4..0c26e3da742e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1347,6 +1347,16 @@ RELAY_REGISTER_OP("broadcast_to_like") .set_attr("TOpPattern", kBroadcast); +// Adapter function to make int array. +Array GetIntArray(Array arr) { + for (size_t i = 0; i < arr.size(); ++i) { + CHECK(!arr[i].defined() || arr[i].as()) + << "Expect an int array"; + } + return Array(arr.node_); +} + + // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); bool StridedSliceRel(const Array& types, @@ -1701,15 +1711,6 @@ Expr MakeSliceLike(Expr data, return CallNode::make(op, {data, shape_like}, Attrs(attrs), {}); } -// Adapter function to make int array. 
-Array GetIntArray(Array arr) { - for (size_t i = 0; i < arr.size(); ++i) { - CHECK(!arr[i].defined() || arr[i].as()) - << "Expect an int array"; - } - return Array(arr.node_); -} - Array SliceLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index 55db8862e849..04f105c44744 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -70,8 +70,10 @@ RELAY_REGISTER_OP("vision.multibox_prior") TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs); -bool MultiBoxTransformLocRel(const Array& types, int num_inputs, - const Attrs& attrs, const TypeReporter& reporter) { +bool MultiBoxTransformLocRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* cls_prob = types[0].as(); diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 3e3f73bc6cb4..6a94da032196 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -9,7 +9,54 @@ namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(NMSAttrs); +TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs); + +bool GetValidCountRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + const auto& dshape = data->shape; + CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; + + std::vector oshape({data->shape[0]}); + std::vector fields; + fields.push_back(TensorTypeNode::make(oshape, Int(32))); + fields.push_back(TensorTypeNode::make(data->shape, data->dtype)); + + // assign output type + reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + return true; +} + +Expr MakeGetValidCounts(Expr data, + double score_threshold) { + auto attrs = make_node(); + attrs->score_threshold = score_threshold; + static const Op& op = Op::Get("vision.get_valid_counts"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.vision._make.get_valid_counts") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeGetValidCounts, args, rv); +}); + + +RELAY_REGISTER_OP("vision.get_valid_counts") +.describe(R"doc(Get valid count of bounding boxes given +a score threshold. Also moves valid boxes to the top of +input data. 
+)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "Input data.") +.set_support_level(5) +.add_type_rel("GetValidCount", GetValidCountRel); + + +TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); bool NMSRel(const Array& types, int num_inputs, @@ -18,39 +65,56 @@ bool NMSRel(const Array& types, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* valid_count = types[1].as(); + const NonMaximumSuppressionAttrs* param = + attrs.as(); const auto& dshape = data->shape; const auto& vshape = valid_count->shape; CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D."; // assign output type - reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype)); + if (param->return_indices) { + std::vector oshape({dshape[0], dshape[1]}); + reporter->Assign(types[2], TensorTypeNode::make(oshape, Int(32))); + } else { + reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype)); + } return true; } Expr MakeNMS(Expr data, Expr valid_count, - double overlap_threshold, + int max_output_size, + double iou_threshold, bool force_suppress, - int topk) { - auto attrs = make_node(); - attrs->overlap_threshold = overlap_threshold; + int top_k, + int id_index, + bool return_indices, + bool invalid_to_bottom) { + auto attrs = make_node(); + attrs->max_output_size = max_output_size; + attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; - attrs->topk = topk; - static const Op& op = Op::Get("vision.nms"); + attrs->top_k = top_k; + attrs->id_index = id_index; + attrs->return_indices = return_indices; + attrs->invalid_to_bottom = invalid_to_bottom; + static const Op& op = Op::Get("vision.non_max_suppression"); return CallNode::make(op, {data, valid_count}, Attrs(attrs), {}); } -TVM_REGISTER_API("relay.op.vision._make.nms") +TVM_REGISTER_API("relay.op.vision._make.non_max_suppression") .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeNMS, args, rv); + runtime::detail::unpack_call(MakeNMS, args, rv); }); -RELAY_REGISTER_OP("vision.nms") -.describe(R"doc("Non-maximum suppression." +RELAY_REGISTER_OP("vision.non_max_suppression") +.describe(R"doc(Non-maximum suppression. The input boxes should +be in the format of [class_id, score, left, top, right, bottom]. +Set id_index to be -1 to ignore class_id axis. 
)doc" TVM_ADD_FILELINE) .set_num_inputs(2) .add_argument("data", "Tensor", "Input data.") diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 2dfe20c503e6..4679876c181b 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -374,6 +374,11 @@ def verify(x_shape, y_shape, axes): verify((3, 4), (2, 3), (0)) verify((3, 4), (2, 3), (-1)) +def test_forward_l2_normalize(): + data = mx.sym.var('data') + mx_sym = mx.sym.L2Normalization(data, mode="channel") + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) + if __name__ == '__main__': test_forward_mlp() @@ -401,5 +406,6 @@ def verify(x_shape, y_shape, axes): test_forward_broadcast_ops() test_forward_elemwise_ops() test_forward_scalar_ops() - test_forward_slice_axis() test_forward_slice_like() + test_forward_slice_axis() + test_forward_l2_normalize() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 34285d2b18dd..7237cfbc3b87 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -2,6 +2,7 @@ """ import numpy as np import tvm +import topi.testing from tvm import relay from tvm.relay.testing import ctx_list import topi diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 003318f01a2f..eceedc760d4b 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -135,56 +135,107 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True) -def test_nms(): - def verify_nms(x0_data, x1_data, dshape, ref_res, valid_count, - overlap_threshold=0.5, force_suppress=False, topk=-1, +def test_get_valid_counts(): + def verify_get_valid_counts(dshape, score_threshold): + dtype = "float32" + batch_size, num_anchor, elem_length = dshape + np_data = np.random.uniform(size=dshape).astype(dtype) + np_out1 = np.zeros(shape=(batch_size,)) + np_out2 = np.zeros(shape=dshape).astype(dtype) + for i in range(batch_size): + np_out1[i] = 0 + inter_idx = 0 + for j in range(num_anchor): + score = np_data[i, j, 1] + if score >= score_threshold: + for k in range(elem_length): + np_out2[i, inter_idx, k] = np_data[i, j, k] + np_out1[i] += 1 + inter_idx += 1 + if j >= np_out1[i]: + for k in range(elem_length): + np_out2[i, j, k] = -1 + + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + z = relay.vision.get_valid_counts(x, score_threshold) + assert "score_threshold" in z.astext() + func = relay.Function([x], z.astuple()) + func = relay.ir_pass.infer_type(func) + ctx_list = [("llvm", tvm.cpu(0))] + for target, ctx in ctx_list: + intrp = relay.create_executor("debug", ctx=ctx, target=target) + out = intrp.evaluate(func)(np_data) + tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3) + + verify_get_valid_counts((1, 2500, 6), 0) + verify_get_valid_counts((1, 2500, 6), -1) + verify_get_valid_counts((3, 1000, 6), 0.55) + verify_get_valid_counts((16, 500, 6), 0.95) + + +def test_non_max_suppression(): + def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, + iou_threshold=0.5, force_suppress=False, top_k=-1, check_type_only=False): x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32")) x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int")) - z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, topk) - 
assert "overlap_threshold" in z.astext() + z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False) + z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k) + assert "iou_threshold" in z.astext() + assert "iou_threshold" in z_indices.astext() zz = relay.ir_pass.infer_type(z) + zz_indices = relay.ir_pass.infer_type(z_indices) assert zz.checked_type == relay.ty.TensorType(dshape, "float32") + assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32") if check_type_only: return func = relay.Function([x0, x1], z) func = relay.ir_pass.infer_type(func) + func_indices = relay.Function([x0, x1], z_indices) + func_indices = relay.ir_pass.infer_type(func_indices) ctx_list = [("llvm", tvm.cpu(0))] for target, ctx in ctx_list: intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x0_data, x1_data) + op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) + tvm.testing.assert_allclose(op_indices_res1.asnumpy(), ref_indices_res, rtol=1e-5) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res2 = intrp2.evaluate(func)(x0_data, x1_data) + op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) + tvm.testing.assert_allclose(op_indices_res2.asnumpy(), ref_indices_res, rtol=1e-5) np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], - [0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], + [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) num_anchors = 5 dshape = (tvm.var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], - force_suppress=True, topk=2, check_type_only=True) + verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result, + force_suppress=True, top_k=2, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], - force_suppress=True, topk=2, check_type_only=False) + verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result, + force_suppress=True, top_k=2, check_type_only=False) np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], - [1, 0.7, 30, 60, 50, 80], [-1, 0.9, 35, 61, 52, 79], + [1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) + np_indices_result = np.array([[3, 0, 1, -1, -1]]) dshape = (tvm.var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], - check_type_only=True) + verify_nms(np_data, np_valid_count, dshape, np_result, + np_indices_result, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], - topk=3) + verify_nms(np_data, np_valid_count, dshape, np_result, + np_indices_result, top_k=3) def test_multibox_transform_loc(): @@ -226,7 +277,7 @@ def test_default_value(): assert ret.checked_type == ref_type - nms = relay.vision.nms(mtl[0], mtl[1]) + nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False) func = relay.Function([cls_prob, loc_pred, 
anchors], nms) func = relay.ir_pass.infer_type(func) ctx_list = [("llvm", tvm.cpu(0))] @@ -411,8 +462,9 @@ def verify_yolo_reorg(shape, stride): test_resize() test_multibox_prior() test_multibox_transform_loc() - test_nms() + test_get_valid_counts() test_roi_align() test_proposal() test_yolo_reorg_infer_shape() test_yolo_reorg() + test_non_max_suppression() diff --git a/topi/include/topi/nn/l2_normalize.h b/topi/include/topi/nn/l2_normalize.h index a9fd49cbee64..4f9bdb61ab70 100644 --- a/topi/include/topi/nn/l2_normalize.h +++ b/topi/include/topi/nn/l2_normalize.h @@ -30,7 +30,12 @@ inline Tensor l2_normalize(const Tensor& data, const Array& axis, std::string name = "tensor", std::string tag = "l2_normalize") { - CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input"; + for (size_t i = 0; i < axis.size(); ++i) { + int ax = topi::detail::GetConstInt(axis[i]); + CHECK_LT(ax, data->shape.size()) << + "Axis " << ax << " exceeds input data dim " << + data->shape.size(); + } auto input_shape = data->shape; Tensor dot_value = topi::power(data, static_cast(2.0)); Tensor sum_value = topi::sum(dot_value, axis, true); diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index e0d71559f1a0..5f79de25e835 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -1,10 +1,10 @@ -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument """Non-maximum suppression operator""" import math import tvm from tvm import api -from topi.vision import nms +from topi.vision import non_max_suppression from ..util import get_const_tuple def sort_ir(data, index, output): @@ -181,13 +181,14 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): return body -@nms.register(["cuda", "gpu"]) -def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1): +@non_max_suppression.register(["cuda", "gpu"]) +def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False, + topk=-1, id_index=0, invalid_to_bottom=False): """Non-maximum suppression operator for object detection. Parameters ---------- - data: tvm.Tensor + data : tvm.Tensor 3-D tensor with shape [batch_size, num_anchors, 6]. The last dimension should be in format of [class_id, score, box_left, box_top, box_right, box_bottom]. @@ -195,15 +196,24 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk valid_count : tvm.Tensor 1-D tensor for valid number of boxes. - nms_threshold : float + return_indices : boolean + Whether to return box indices in input data. + + iou_threshold : optional, float Non-maximum suppression threshold. - force_suppress : boolean + force_suppress : optional, boolean Whether to suppress all detections regardless of class_id. - nms_topk : int + topk : optional, int Keep maximum top k detections before nms, -1 for no limit. + id_index : optional, int + index of the class categories, -1 to disable. + + invalid_to_bottom : optional, boolean + Whether to move all valid bounding boxes to the top. 
+ Returns ------- out : tvm.Tensor @@ -216,14 +226,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk # An example to use nms dshape = (1, 5, 6) data = tvm.placeholder(dshape, name="data") - valid_count = tvm.placeholder( - (dshape[0],), dtype="int32", name="valid_count") - nms_threshold = 0.7 + valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") + iou_threshold = 0.7 force_suppress = True - nms_topk = -1 - out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) - np_data = np.random.uniform(size=dshape).astype("float32") - np_valid_count = np.array([4]).astype("int32") + topk = -1 + out = nms(data, valid_count, iou_threshold, force_suppress, topk) + np_data = np.random.uniform(dshape) + np_valid_count = np.array([4]) s = topi.generic.schedule_nms(out) f = tvm.build(s, [data, valid_count, out], "llvm") ctx = tvm.cpu() @@ -263,8 +272,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk tvm.extern(data.shape, [data, sort_tensor, valid_count], lambda ins, outs: nms_ir( - ins[0], ins[1], ins[2], outs[0], nms_threshold, - force_suppress, nms_topk), + ins[0], ins[1], ins[2], outs[0], iou_threshold, + force_suppress, topk), dtype="float32", in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], tag="nms") diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index 746be092ebbe..11062824deb0 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -11,7 +11,7 @@ from topi.vision.ssd import multibox_prior from topi.vision.ssd import multibox_detection from topi.vision.ssd import multibox_transform_loc -from ..nms import nms +from ..nms import non_max_suppression def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): @@ -437,6 +437,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = nms( + out = non_max_suppression( inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) return out diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 17497abc0d8b..e3bc0fb9d547 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -162,3 +162,20 @@ def traverse(op): scheduled_ops.append(op) traverse(outs[0].op) return s + +@generic.schedule_get_valid_counts.register(["cuda", "gpu"]) +def schedule_get_valid_counts(outs): + """Schedule for get_valid_counts operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of get_valid_counts + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs) diff --git a/topi/python/topi/generic/vision.py b/topi/python/topi/generic/vision.py index 76e8545bfc52..bfd6c55d533a 100644 --- a/topi/python/topi/generic/vision.py +++ b/topi/python/topi/generic/vision.py @@ -36,6 +36,23 @@ def schedule_reorg(outs): cpp_target = cpp.TEST_create_target(target.target_name) return cpp.generic.default_schedule(cpp_target, outs, False) +@tvm.target.generic_func +def schedule_get_valid_counts(outs): + """Schedule for get_valid_counts + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of nms + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + @tvm.target.generic_func def schedule_nms(outs): """Schedule for non-maximum suppression diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 0ccc422010c1..1743de13fd85 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -20,3 +20,4 @@ from .gather_nd_python import gather_nd_python from .strided_slice_python import strided_slice_python from .batch_matmul import batch_matmul +from .slice_axis_python import slice_axis_python diff --git a/topi/python/topi/testing/slice_axis_python.py b/topi/python/topi/testing/slice_axis_python.py new file mode 100644 index 000000000000..589e5914a36c --- /dev/null +++ b/topi/python/topi/testing/slice_axis_python.py @@ -0,0 +1,34 @@ +"""Slice axis in python""" + +def slice_axis_python(data, axis, begin, end=None): + """Slice input array along specific axis. + + Parameters + ---------- + data : numpy.ndarray + The source array to be sliced. + + axis : int + Axis to be sliced. + + begin: int + The index to begin with in the slicing. + + end: int, optional + The index indicating end of the slice. + + Returns + ------- + ret : numpy.ndarray + The computed result. + """ + dshape = data.shape + if axis < 0: + axis += len(dshape) + if begin < 0: + begin += dshape[axis] + if end <= 0: + end += dshape[axis] + slc = [slice(None)] * len(dshape) + slc[axis] = slice(begin, end) + return data[tuple(slc)] diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index a41ee5b50089..169daea2d4d3 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -1,118 +1,247 @@ -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements """Non-maximum suppression operator""" import tvm -from tvm import api +from tvm import api, hybrid -def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk): - """Low level IR routing for transform location in multibox_detection operator. +@hybrid.script +def hybrid_rearrange_out(data): + """Hybrid routine to rearrange nms output to + move all valid entries to top. Parameters ---------- - data: Buffer - Buffer of output boxes with class and score. + data : tvm.Tensor or numpy NDArray + NMS output. 3-D tensor with shape + [batch_size, num_anchors, 6]. - sort_result : Buffer - Buffer of output box indexes sorted by score. + Returns + ------- + output : tvm.Tensor or numpy NDArray + Transformed NMS output. 3-D tensor with shape + [batch_size, num_anchors, 6]. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + output = output_tensor((batch_size, + num_anchors, + elem_length), + data.dtype) - valid_count : Buffer - Buffer of number of valid output boxes. + for i in parallel(batch_size): + valid_idx = 0 + for j in range(num_anchors): + if data[i, j, 0] >= 0: + for k in range(elem_length): + output[i, valid_idx, k] = data[i, j, k] + valid_idx += 1 + if j >= valid_idx: + for k in range(elem_length): + output[i, j, k] = -1.0 + return output - out : Buffer - Output buffer. - nms_threshold : float - Non-maximum suppression threshold. +@hybrid.script +def hybrid_get_valid_counts(data, score_threshold): + """Hybrid routine to get valid count of bounding boxes + given a score threshold. 
Also moves valid boxes to the + top of input data. + + Parameters + ---------- + data : tvm.Tensor or numpy NDArray + Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + + score_threshold : tvm.const + Lower limit of score for valid bounding boxes. + + Returns + ------- + out_tensor : tvm.Tensor or numpy NDArray + Rearranged data tensor. + + valid_count : tvm.Tensor or numpy NDArray + 1-D tensor for valid number of boxes. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] + valid_count = output_tensor((batch_size,), "int32") + out_tensor = output_tensor((batch_size, + num_anchors, + box_data_length), + data.dtype) + for i in parallel(batch_size): + valid_count[i] = 0 + for j in range(num_anchors): + score = data[i, j, 1] + if score > score_threshold: + for k in range(box_data_length): + out_tensor[i, valid_count[i], k] = data[i, j, k] + valid_count[i] += 1 + if j >= valid_count[i]: + for k in range(box_data_length): + out_tensor[i, j, k] = -1.0 + return valid_count, out_tensor + +@tvm.target.generic_func +def get_valid_counts(data, score_threshold=0): + """Get valid count of bounding boxes given a score threshold. + Also moves valid boxes to the top of input data. + + Parameters + ---------- + data : tvm.Tensor + Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + + score_threshold : optional, float + Lower limit of score for valid bounding boxes. + + Returns + ------- + out_tensor : tvm.Tensor + Rearranged data tensor. + + valid_count : tvm.Tensor + 1-D tensor for valid number of boxes. + """ + score_threshold_const = tvm.const(score_threshold, "float") + return hybrid_get_valid_counts(data, score_threshold_const) + + +@hybrid.script +def hybrid_nms(data, sorted_index, valid_count, + max_output_size, iou_threshold, force_suppress, + top_k, id_index): + """Hybrid routing for non-maximum suppression. + + Parameters + ---------- + data: tvm.Tensor or numpy NDArray + Bounding boxes with class and score. 3-D tensor with shape + [batch_size, num_anchors, 6]. + + sorted_index : tvm.Tensor or numpy NDArray + Bounding box indexes sorted by score, with shape + [batch_size, num_anchors]. + + valid_count : tvm.Tensor or numpy NDArray + 1-D tensor for valid number of boxes. - force_suppress : boolean + max_output_size : tvm.const + Max number of output valid boxes for each instance. + By default all valid boxes are returned. + + iou_threshold : tvm.const + Overlapping(IoU) threshold to suppress object with smaller score. + + force_suppress : tvm.const Whether to suppress all detections regardless of class_id. - nms_topk : int + top_k : tvm.const Keep maximum top k detections before nms, -1 for no limit. + id_index : tvm.const + index of the class categories, -1 to disable. + Returns ------- - stmt : Stmt - The result IR statement. + output : tvm.Tensor + 3-D tensor with shape [batch_size, num_anchors, 6]. + + box_indices: tvm.Tensor + 2-D tensor with shape [batch_size, num_anchors]. """ - def calculate_overlap(out_tensor, box_a_idx, box_b_idx): - """Calculate overlap of two boxes. 
- """ - w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) - - tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx])) - h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) - - tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1])) - i = w * h - u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \ - (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \ - (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \ - (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i - return tvm.expr.Select(u <= 0.0, 0.0, i / u) - - ib = tvm.ir_builder.create() - p_data = ib.buffer_ptr(data) - p_sort_result = ib.buffer_ptr(sort_result) - p_valid_count = ib.buffer_ptr(valid_count) - p_out = ib.buffer_ptr(out) - batch_size = out.shape[0] - num_anchors = out.shape[1] - - nms_threshold_node = tvm.make.node("FloatImm", dtype="float32", value=nms_threshold) - nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk) - force_suppress_node = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0) - with ib.for_range(0, batch_size, for_type="parallel", name="n") as n: - with ib.if_scope(tvm.all(nms_threshold_node > 0, nms_threshold_node < 1, - p_valid_count[0] > 0)): - # Reorder output - nkeep = tvm.if_then_else( - tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]), - nms_topk, p_valid_count[n]) - with ib.for_range(0, nkeep, name="l") as l: - with ib.for_range(0, 6, name="m") as m: - p_out[(n * num_anchors * 6 - + l * 6 + m)] = p_data[(n * num_anchors * 6 - + p_sort_result[n * num_anchors + l] * 6 + m)] - with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])): - with ib.for_range(0, p_valid_count[n] - nkeep, name="l") as l: - with ib.for_range(0, 6, name="m") as m: - p_out[(n * num_anchors * 6 - + (l + nkeep) * 6 + m)] = p_data[(n * num_anchors * 6 - + (l + nkeep) * 6 + m)] + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] + box_indices = output_tensor((batch_size, num_anchors), "int32") + output = output_tensor((batch_size, + num_anchors, + box_data_length,), + data.dtype) + + for i in parallel(batch_size): + if iou_threshold > 0: + if valid_count[i] > 0: + # Reorder output + nkeep = valid_count[i] + if 0 < top_k < nkeep: + nkeep = top_k + for j in range(nkeep): + for k in range(box_data_length): + output[i, j, k] = data[i, sorted_index[i, j], k] + box_indices[i, j] = sorted_index[i, j] + if 0 < top_k < valid_count[i]: + for j in range(valid_count[i] - nkeep): + for k in range(box_data_length): + output[i, j + nkeep, k] = -1.0 + box_indices[i, j + nkeep] = -1 # Apply nms - with ib.for_range(0, p_valid_count[n], name="l") as l: - offset_l = l * 6 - with ib.if_scope(p_out[n * num_anchors * 6 + offset_l] >= 0): - with ib.for_range(0, p_valid_count[n], name="m") as m: - offset_m = m * 6 - with ib.if_scope(tvm.all(m > l, p_out[n * num_anchors * 6 - + offset_m] >= 0)): - with ib.if_scope(tvm.any(force_suppress_node > 0, - p_out[n * num_anchors * 6 + offset_l] == - p_out[n * num_anchors * 6 + offset_m])): - # When force_suppress == True or class_id equals - iou = calculate_overlap(p_out, n * num_anchors * 6 + offset_l + 2, - n * num_anchors * 6 + offset_m + 2) - with ib.if_scope(iou >= nms_threshold): - p_out[n * num_anchors * 6 + offset_m] = -1.0 - with ib.else_scope(): - with ib.for_range(0, p_valid_count[n], name="l") as l: - with ib.for_range(0, 6, name="m") as m: - p_out[(n * num_anchors * 6 - + l * 6 + m)] = 
p_data[n * num_anchors * 6 + l * 6 + m] + for j in range(valid_count[i]): + if output[i, j, 0] >= 0: + for k in range(valid_count[i]): + check_iou = 0 + if k > j and output[i, k, 0] >= 0: + if force_suppress: + check_iou = 1 + elif id_index < 0 or output[i, j, 0] == output[i, k, 0]: + check_iou = 1 + if check_iou > 0: + batch_idx = i + box_a_idx = j + box_b_idx = k + box_start_idx = 2 + a_t = output[batch_idx, box_a_idx, box_start_idx + 1] + a_b = output[batch_idx, box_a_idx, box_start_idx + 3] + a_l = output[batch_idx, box_a_idx, box_start_idx] + a_r = output[batch_idx, box_a_idx, box_start_idx + 2] + b_t = output[batch_idx, box_b_idx, box_start_idx + 1] + b_b = output[batch_idx, box_b_idx, box_start_idx + 3] + b_l = output[batch_idx, box_b_idx, box_start_idx] + b_r = output[batch_idx, box_b_idx, box_start_idx + 2] + w = max(0.0, min(a_r, b_r) - max(a_l, b_l)) + h = max(0.0, min(a_b, b_b) - max(a_t, b_t)) + area = h * w + u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + iou = 0.0 if u <= 0.0 else area / u + if iou >= iou_threshold: + output[i, k, 0] = -1.0 + box_indices[i, k] = -1 + else: + for j in range(valid_count[i]): + for k in range(box_data_length): + output[i, j, k] = data[i, j, k] + box_indices[i, j] = j # Set invalid entry to be -1 - with ib.for_range(0, num_anchors - p_valid_count[n], name="l") as l: - with ib.for_range(0, 6, name="m") as m: - p_out[n * num_anchors * 6 + (l + p_valid_count[n]) * 6 + m] = -1.0 - return ib.get() + for j in range(num_anchors - valid_count[i]): + for k in range(box_data_length): + output[i, j + valid_count[i], k] = -1.0 + box_indices[i, j + valid_count[i]] = -1 + # Only return max_output_size valid boxes + num_valid_boxes = 0 + if max_output_size > 0: + for j in range(valid_count[i]): + if output[i, j, 0] >= 0: + if num_valid_boxes == max_output_size: + for k in range(box_data_length): + output[i, j, k] = -1.0 + box_indices[i, j] = -1 + else: + num_valid_boxes += 1 + return output, box_indices @tvm.target.generic_func -def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1): +def non_max_suppression(data, valid_count, max_output_size=-1, + iou_threshold=0.5, force_suppress=False, top_k=-1, + id_index=0, return_indices=True, invalid_to_bottom=False): """Non-maximum suppression operator for object detection. Parameters ---------- - data: tvm.Tensor + data : tvm.Tensor 3-D tensor with shape [batch_size, num_anchors, 6]. The last dimension should be in format of [class_id, score, box_left, box_top, box_right, box_bottom]. @@ -120,15 +249,28 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) valid_count : tvm.Tensor 1-D tensor for valid number of boxes. - nms_threshold : float + max_output_size : optional, int + Max number of output valid boxes for each instance. + By default all valid boxes are returned. + + iou_threshold : optional, float Non-maximum suppression threshold. - force_suppress : boolean + force_suppress : optional, boolean Whether to suppress all detections regardless of class_id. - nms_topk : int + top_k : optional, int Keep maximum top k detections before nms, -1 for no limit. + id_index : optional, int + Index of the class categories, -1 to disable. + + return_indices : optional, boolean + Whether to return box indices in input data. + + invalid_to_bottom : optional, boolean + Whether to move all invalid bounding boxes to the bottom.
+ Returns ------- out : tvm.Tensor @@ -138,16 +280,17 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) -------- .. code-block:: python - # An example to use nms + # An example to use non_max_suppression dshape = (1, 5, 6) data = tvm.placeholder(dshape, name="data") valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") - nms_threshold = 0.7 + iou_threshold = 0.7 force_suppress = True - nms_topk = -1 - out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) - np_data = np.random.uniform(size=dshape).astype("float32") - np_valid_count = np.array([4]).astype("int32") + top_k = -1 + out = non_max_suppression(data, valid_count, iou_threshold=iou_threshold, + force_suppress=force_suppress, top_k=top_k) + np_data = np.random.uniform(dshape) + np_valid_count = np.array([4]) s = topi.generic.schedule_nms(out) f = tvm.build(s, [data, valid_count, out], "llvm") ctx = tvm.cpu() @@ -161,7 +304,6 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) valid_count_dtype = "int32" valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4) - data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) score_axis = 1 score_shape = (batch_size, num_anchors) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) @@ -180,13 +322,13 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) in_buffers=[score_tensor_buf, valid_count_buf], out_buffers=sort_tensor_buf, name="nms_sort") - out = \ - tvm.extern(data.shape, - [data, sort_tensor, valid_count], - lambda ins, outs: nms_ir( - ins[0], ins[1], ins[2], outs[0], nms_threshold, - force_suppress, nms_topk), - dtype="float32", - in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], - tag="nms") - return out + out, box_indices = hybrid_nms(data, sort_tensor, valid_count, + tvm.const(max_output_size, dtype="int32"), + tvm.const(iou_threshold, dtype="float32"), + tvm.const(force_suppress, dtype="bool"), + tvm.const(top_k, dtype="int32"), + tvm.const(id_index, dtype="int32")) + if not return_indices and invalid_to_bottom: + out = hybrid_rearrange_out(out) + + return box_indices if return_indices else out diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index f1de42430dd6..2de1723dbd7b 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -1,75 +1,76 @@ -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable """SSD multibox operators""" from __future__ import absolute_import as _abs -import math import tvm -from tvm import api +from tvm import hybrid +from tvm.intrin import exp, sqrt import topi -from ..nms import nms +from ..nms import non_max_suppression -def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): - """Low level IR routing for multibox_prior operator. +@hybrid.script +def hybrid_multibox_prior(data, sizes, ratios, steps, offsets): + """Hybrid routing for multibox_prior operator. Parameters ---------- - data : Buffer - Input data buffer. + data : tvm.Tensor or numpy NDArray + 4-D tensor with shape [batch, channel, height, width]] - out : Buffer - Output buffer. + sizes : tvm ConsExpr + Sizes for anchor boxes. - sizes : tuple of float - Tuple of sizes for anchor boxes. 
- - ratios : tuple of float - Tuple of ratios for anchor boxes. + ratios : tvm ConsExpr + Ratios for anchor boxes. - steps : Tuple of float + steps : tvm ConsExpr Priorbox step across y and x, -1 for auto calculation. - offsets : tuple of int + offsets : tvm ConsExpr Priorbox center offsets, y and x respectively. Returns ------- - stmt : Stmt - The result IR statement. + output : tvm.Tensor or numpy NDArray + 3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4] """ - ib = tvm.ir_builder.create() - p_out = ib.buffer_ptr(out) in_height = data.shape[2] in_width = data.shape[3] num_sizes = len(sizes) num_ratios = len(ratios) - size_ratio_concat = sizes + ratios - steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height - steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width + num_boxes = in_height * in_width * (num_sizes + num_ratios - 1) + output = output_tensor((1, num_boxes, 4), "float32") + steps_h = steps[0] * 1.0 if steps[0] > 0 else 1.0 / in_height + steps_w = steps[1] * 1.0 if steps[1] > 0 else 1.0 / in_width offset_h = offsets[0] offset_w = offsets[1] - with ib.for_range(0, in_height, for_type="parallel", name="i") as i: + # Need to define var out of const_range + if + w = 0.0 + h = 0.0 + + for i in parallel(in_height): center_h = (i + offset_h) * steps_h - with ib.for_range(0, in_width, name="j") as j: + for j in range(in_width): center_w = (j + offset_w) * steps_w - for k in range(num_sizes + num_ratios - 1): - w = tvm.if_then_else(k < num_sizes, - size_ratio_concat[k] * in_height / in_width / 2.0, - size_ratio_concat[0] * in_height / in_width * - math.sqrt(size_ratio_concat[k + 1]) / 2.0) - h = tvm.if_then_else( - k < num_sizes, size_ratio_concat[k] / 2.0, - size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0) - count = (i * in_width * (num_sizes + num_ratios - 1) + - j * (num_sizes + num_ratios - 1) + k) * 4 - p_out[count] = center_w - w - p_out[count + 1] = center_h - h - p_out[count + 2] = center_w + w - p_out[count + 3] = center_h + h - - return ib.get() + for k in const_range(num_sizes + num_ratios - 1): + if k < num_sizes: + w = sizes[k] * in_height / in_width / 2.0 + h = sizes[k] / 2.0 + else: + w = sizes[0] * in_height / in_width \ + * sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0 + h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0 + count = i * in_width * (num_sizes + num_ratios - 1) \ + + j * (num_sizes + num_ratios - 1) + k + output[0, count, 0] = center_w - w + output[0, count, 1] = center_h - h + output[0, count, 2] = center_w + w + output[0, count, 3] = center_h + h + + return output @tvm.target.generic_func @@ -101,115 +102,120 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, out : tvm.Tensor 3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4] """ - num_sizes = len(sizes) - num_ratios = len(ratios) - oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4) - out = tvm.extern(oshape, [data], lambda ins, outs: - multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets), - tag="multibox_prior") + out = hybrid_multibox_prior(data, tvm.convert(sizes), tvm.convert(ratios), + tvm.convert(steps), tvm.convert(offsets)) if clip: out = topi.clip(out, 0, 1) return out -def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances): - """Low level IR routing for transform location in multibox_detection operator. 
+@hybrid.script +def _hybridy_transform_loc(box, pred_loc, variance, clip): + """Transform prior anchor box to output box through location predictions. + """ + al = box[0] + at = box[1] + ar = box[2] + ab = box[3] + + px = pred_loc[0] + py = pred_loc[1] + pw = pred_loc[2] + ph = pred_loc[3] + + vx = variance[0] + vy = variance[1] + vw = variance[2] + vh = variance[3] + + output = output_tensor((4,), pred_loc.dtype) + + aw = ar - al + ah = ab - at + ax = (al + ar) / 2.0 + ay = (at + ab) / 2.0 + ox = px * vx * aw + ax + oy = py * vy * ah + ay + ow = exp(pw * vw) * aw / 2.0 + oh = exp(ph * vh) * ah / 2.0 + output[0] = max(0.0, min(1.0, ox - ow)) if clip else ox - ow + output[1] = max(0.0, min(1.0, oy - oh)) if clip else oy - oh + output[2] = max(0.0, min(1.0, ox + ow)) if clip else ox + ow + output[3] = max(0.0, min(1.0, oy + oh)) if clip else oy + oh + return output + +@hybrid.script +def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor, + clip, threshold, variances): + """Hybrid routing for transform location in multibox_detection operator. Parameters ---------- - cls_prob : Buffer - Buffer of class probabilities. + cls_prob : tvm.Tensor or numpy NDArray + 3-D tensor of class probabilities. - loc_pred : Buffer - Buffer of location regression predictions. + loc_pred : tvm.Tensor or numpy NDArray + 2-D tensor of location regression predictions. - anchor : Buffer - Buffer of prior anchor boxes. + anchor : tvm.Tensor or numpy NDArray + 3-D tensor of prior anchor boxes. - valid_count : Buffer - Buffer of number of valid output boxes. - - out : Buffer - Output buffer. - - clip : boolean + clip : tvm.const Whether to clip out-of-boundary boxes. - threshold : float + threshold : tvm.const Threshold to be a positive prediction. - variances : tuple of float + variances : tvm.ndarray Variances to be decoded from box regression output. Returns ------- - stmt : Stmt - The result IR statement. - """ - def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh): - """Transform prior anchor box to output box through location predictions. - """ - al = anchor[anchor_base_idx] - at = anchor[anchor_base_idx + 1] - ar = anchor[anchor_base_idx + 2] - ab = anchor[anchor_base_idx + 3] - aw = ar - al - ah = ab - at - ax = (al + ar) / 2.0 - ay = (at + ab) / 2.0 - px = loc[loc_base_idx] - py = loc[loc_base_idx + 1] - pw = loc[loc_base_idx + 2] - ph = loc[loc_base_idx + 3] - ox = px * vx * aw + ax - oy = py * vy * ah + ay - ow = tvm.exp(pw * vw) * aw / 2.0 - oh = tvm.exp(ph * vh) * ah / 2.0 - return tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \ - tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \ - tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \ - tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh) + out_loc : tvm.Tensor or numpy NDArray + 3-D tensor of transformed location. + valid_count : tvm.Tensor or numpy NDArray + 1_d tensor of valid counts for boxes. 
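The decoding done per box by _hybridy_transform_loc above is the usual SSD transform: the prediction shifts the anchor center by a variance-scaled fraction of the anchor size and rescales its extent exponentially, then optionally clips to the unit square. A compact NumPy sketch of the same formulas (the helper name and corner-format inputs are illustrative):

.. code-block:: python

    import numpy as np

    def decode_box(anchor, pred, variances=(0.1, 0.1, 0.2, 0.2), clip=True):
        """Decode one location prediction against one anchor box (corner format)."""
        al, at, ar, ab = anchor
        px, py, pw, ph = pred
        vx, vy, vw, vh = variances
        aw, ah = ar - al, ab - at                   # anchor width / height
        ax, ay = (al + ar) / 2.0, (at + ab) / 2.0   # anchor center
        ox = px * vx * aw + ax                      # decoded center
        oy = py * vy * ah + ay
        ow = np.exp(pw * vw) * aw / 2.0             # decoded half extents
        oh = np.exp(ph * vh) * ah / 2.0
        box = np.array([ox - ow, oy - oh, ox + ow, oy + oh])
        return np.clip(box, 0.0, 1.0) if clip else box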
+ """ batch_size = cls_prob.shape[0] num_classes = cls_prob.shape[1] num_anchors = cls_prob.shape[2] - - ib = tvm.ir_builder.create() - p_cls_prob = ib.buffer_ptr(cls_prob) - p_loc_pred = ib.buffer_ptr(loc_pred) - p_anchor = ib.buffer_ptr(anchor) - p_valid_count = ib.buffer_ptr(valid_count) - p_out = ib.buffer_ptr(out) - with ib.for_range(0, batch_size, for_type="parallel", name="n") as n: - p_valid_count[n] = 0 - with ib.for_range(0, num_anchors, name="i") as i: + box_coord = allocate((4,), loc_pred.dtype) + pred_coord = allocate((4,), loc_pred.dtype) + out_loc = output_tensor((batch_size, num_anchors, 6), + loc_pred.dtype) + valid_count = output_tensor((batch_size,), "int32") + + for i in parallel(batch_size): + valid_count[i] = 0 + for j in range(num_anchors): # Find the predicted class id and probability - score = ib.allocate('float32', (1,), name="score", scope="local") - cls_id = ib.allocate('int32', (1,), name="id", scope="local") - score[0] = -1.0 - cls_id[0] = 0 - with ib.for_range(0, num_classes, name="j") as j: - with ib.if_scope(j > 0): - temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i] - cls_id[0] = tvm.if_then_else(temp > score[0], j, cls_id[0]) - score[0] = tvm.max(temp, score[0]) - with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)): - cls_id[0] = 0 + score = -1.0 + cls_id = 0 + for k in range(num_classes): + if k > 0: + temp = cls_prob[i, k, j] + cls_id = k if temp > score else cls_id + score = max(temp, score) + if cls_id > 0 and score < threshold: + cls_id = 0 # [id, prob, xmin, ymin, xmax, ymax] # Remove background, restore original id - with ib.if_scope(cls_id[0] > 0): - out_base_idx = n * num_anchors * 6 + p_valid_count[n] * 6 - p_out[out_base_idx] = cls_id[0] - 1.0 - p_out[out_base_idx + 1] = score[0] - offset = i * 4 - p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \ - p_out[out_base_idx + 5] = transform_loc(p_loc_pred, n * num_anchors * 4 + offset, - p_anchor, offset, clip, variances[0], - variances[1], variances[2], variances[3]) - p_valid_count[n] += 1 - - return ib.get() - + if cls_id > 0: + out_loc[i, valid_count[i], 0] = cls_id - 1.0 + out_loc[i, valid_count[i], 1] = score + for l in range(4): + box_coord[l] = anchor[0, j, l] + pred_coord[l] = loc_pred[i, j * 4 + l] + out_coord = _hybridy_transform_loc(box_coord, pred_coord, + variances, clip) + out_loc[i, valid_count[i], 2] = out_coord[0] + out_loc[i, valid_count[i], 3] = out_coord[1] + out_loc[i, valid_count[i], 4] = out_coord[2] + out_loc[i, valid_count[i], 5] = out_coord[3] + valid_count[i] += 1 + + return out_loc, valid_count @tvm.target.generic_func def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, @@ -240,24 +246,10 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 ------- ret : tuple of tvm.Tensor """ - batch_size = cls_prob.shape[0] - num_anchors = anchor.shape[1] - oshape = (batch_size, num_anchors, 6) - # Define data alignment for intermediate buffer - valid_count_dtype = "int32" - valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype, - "valid_count_buf", data_alignment=4) - out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8) - valid_count, out = \ - tvm.extern([(batch_size,), oshape], - [cls_prob, loc_pred, anchor], - lambda ins, outs: transform_loc_ir( - ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances), - dtype=[valid_count_dtype, cls_prob.dtype], - out_buffers=[valid_count_buf, out_buf], - 
tag="multibox_transform_loc") - return [out, valid_count] - + return hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor, + tvm.const(clip, "bool"), + tvm.const(threshold, "float32"), + tvm.convert(variances)) @tvm.target.generic_func def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5, @@ -300,5 +292,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) + out = non_max_suppression(inter_out[0], inter_out[1], -1, + nms_threshold, force_suppress, nms_topk, + return_indices=False) return out diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 3c0c3aa854d7..02e04212b63e 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -8,11 +8,62 @@ from tvm.contrib.pickle_memoize import memoize from topi.util import get_const_tuple -from topi.vision import ssd, nms +from topi.vision import ssd, non_max_suppression, get_valid_counts + + +def verify_get_valid_counts(dshape, score_threshold): + dtype = "float32" + batch_size, num_anchor, elem_length = dshape + np_data = np.random.uniform(size=dshape).astype(dtype) + np_out1 = np.zeros(shape=(batch_size,)) + np_out2 = np.zeros(shape=dshape).astype(dtype) + for i in range(batch_size): + np_out1[i] = 0 + inter_idx = 0 + for j in range(num_anchor): + score = np_data[i, j, 1] + if score > score_threshold: + for k in range(elem_length): + np_out2[i, inter_idx, k] = np_data[i, j, k] + np_out1[i] += 1 + inter_idx += 1 + if j >= np_out1[i]: + for k in range(elem_length): + np_out2[i, j, k] = -1.0 + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + data = tvm.placeholder(dshape, name="data", dtype=dtype) + outs = get_valid_counts(data, score_threshold) + s = topi.generic.schedule_multibox_prior(outs) + + tvm_input_data = tvm.nd.array(np_data, ctx) + tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx) + tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx) + f = tvm.build(s, [data, outs[0], outs[1]], device) + f(tvm_input_data, tvm_out1, tvm_out2) + tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) -def test_nms(): + for device in ['llvm']: + check_device(device) + + +def test_get_valid_counts(): + verify_get_valid_counts((1, 2500, 6), 0) + verify_get_valid_counts((1, 2500, 6), -1) + verify_get_valid_counts((3, 1000, 6), 0.55) + verify_get_valid_counts((16, 500, 6), 0.95) + + +def test_non_max_suppression(): dshape = (1, 5, 6) + indices_dshape = (1, 5) data = tvm.placeholder(dshape, name="data") valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") nms_threshold = 0.7 @@ -24,8 +75,9 @@ def test_nms(): [1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype) np_valid_count = np.array([4]).astype(valid_count.dtype) np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], - [0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], + [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) def check_device(device): ctx = tvm.context(device, 0) @@ 
-35,18 +87,27 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): if device == 'llvm': - out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) + out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) + indices_out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) else: - out = topi.cuda.nms(data, valid_count, nms_threshold, force_suppress, nms_topk) + out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) + indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) s = topi.generic.schedule_nms(out) + indices_s = topi.generic.schedule_nms(indices_out) tvm_data = tvm.nd.array(np_data, ctx) tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) f = tvm.build(s, [data, valid_count, out], device) f(tvm_data, tvm_valid_count, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) + tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx) + f = tvm.build(indices_s, [data, valid_count, indices_out], device) + f(tvm_data, tvm_valid_count, tvm_indices_out) + tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) + for device in ['llvm']: check_device(device) @@ -274,7 +335,8 @@ def test_proposal(): if __name__ == "__main__": - test_nms() + test_get_valid_counts() + test_non_max_suppression() test_multibox_prior() test_multibox_detection() test_roi_align() diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py new file mode 100644 index 000000000000..6a5d63b9f8cf --- /dev/null +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -0,0 +1,104 @@ +""" +Deploy Single Shot Multibox Detector(SSD) model +=============================================== +**Author**: `Yao Wang `_ + +This article is an introductory tutorial to deploy SSD models with TVM. +We will use GluonCV pre-trained SSD model and convert it to Relay IR +""" +import tvm + +from matplotlib import pyplot as plt +from nnvm import compiler +from nnvm.frontend import from_mxnet +from nnvm.testing.config import ctx_list +from tvm import relay +from tvm.contrib import graph_runtime +from gluoncv import model_zoo, data, utils + + +###################################################################### +# Preliminary and Set parameters +# ------------------------------ +# We should build TVM with sort support, in TVM root directory +# +# .. code-block:: bash +# +# echo "set(USE_SORT ON)" > config.mk +# make -j8 +# +# .. note:: +# +# Currently we support compiling SSD on CPU only. +# GPU support is in progress. +# +# To get best inference performance on CPU, change +# target argument according to your device and +# follow the :ref:`tune_relay_x86` to tune x86 CPU and +# :ref:`tune_relay_arm` for arm cpu. +# +# SSD with VGG as body network is not supported yet since +# x86 conv2d schedule doesn't support dilation. 
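As the note above says, the target argument should match the deployment CPU; two commonly used target strings are shown below (examples only — adjust to your hardware):

.. code-block:: python

    # x86 CPU with AVX2 support:
    target = "llvm -mcpu=core-avx2"
    # 64-bit ARM (AArch64) CPU, cross-compiled:
    # target = "llvm -device=arm_cpu -target=aarch64-linux-gnu"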
+ +supported_model = [ + 'ssd_512_resnet18_v1_voc', + 'ssd_512_resnet18_v1_coco', + 'ssd_512_resnet50_v1_voc', + 'ssd_512_resnet50_v1_coco', + 'ssd_512_resnet101_v2_voc', + 'ssd_512_mobilenet1_0_voc', + 'ssd_512_mobilenet1_0_coco', +] + +model_name = "ssd_512_resnet50_v1_voc" +dshape = (1, 3, 512, 512) +dtype = "float32" +target_list = ctx_list() + +###################################################################### +# Download and pre-process demo image + +im_fname = utils.download('https://github.com/dmlc/web-data/blob/master/' + + 'gluoncv/detection/street_small.jpg?raw=true', + path='street_small.jpg') +x, img = data.transforms.presets.ssd.load_test(im_fname, short=512) + +###################################################################### +# Convert and compile model for CPU. + +block = model_zoo.get_model(model_name, pretrained=True) + +def compile(target): + net, params = relay.frontend.from_mxnet(block, {"data": dshape}) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(net, target, params=params) + return graph, lib, params + +###################################################################### +# Create TVM runtime and do inference + +def run(graph, lib, params, ctx): + # Build TVM runtime + m = graph_runtime.create(graph, lib, ctx) + tvm_input = tvm.nd.array(x.asnumpy(), ctx=ctx) + m.set_input('data', tvm_input) + m.set_input(**params) + # execute + m.run() + # get outputs + class_IDs, scores, bounding_boxs = m.get_output(0), m.get_output(1), m.get_output(2) + return class_IDs, scores, bounding_boxs + +for target, ctx in target_list: + if target == "cuda": + print("GPU not supported yet, skip.") + continue + graph, lib, params = compile(target) + class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx) + +###################################################################### +# Display result + +ax = utils.viz.plot_bbox(img, bounding_boxs.asnumpy()[0], scores.asnumpy()[0], + class_IDs.asnumpy()[0], class_names=block.classes) +plt.show() diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd_mxnet.py similarity index 98% rename from tutorials/nnvm/deploy_ssd.py rename to tutorials/nnvm/deploy_ssd_mxnet.py index eadb8fd28e0c..1a71c96eaa0c 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd_mxnet.py @@ -61,7 +61,7 @@ image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \ "cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg" inference_symbol_folder = \ -"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" + "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \ "archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
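For context, inference_symbol_url above points at a gist archive that unpacks into the directory named by inference_symbol_folder. A minimal, standard-library-only sketch of fetching and extracting such an archive (an illustration, not the tutorial's own download code):

.. code-block:: python

    import os
    import zipfile
    import urllib.request

    def fetch_and_unpack(url, folder, dest="."):
        """Download the gist archive once and extract it into `dest`."""
        zip_path = os.path.join(dest, folder + ".zip")
        if not os.path.isdir(os.path.join(dest, folder)):
            urllib.request.urlretrieve(url, zip_path)
            with zipfile.ZipFile(zip_path) as zf:
                zf.extractall(dest)
        return os.path.join(dest, folder)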