From e77a68ed67dad4c2125cd2c14860c8ce154d73d2 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 8 Oct 2020 14:00:24 -0600 Subject: [PATCH 01/10] NMS partially working on CPU, fails on GPU --- python/tvm/relay/frontend/onnx.py | 116 +++++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 67 +++++++++++- 2 files changed, 181 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f0d7e2d21d40..2f40c68457d1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2303,6 +2303,121 @@ def _impl_v1(cls, inputs, attr, params): return _expr.If(cond, then_expr, else_expr) +class NonMaxSuppression(OnnxOpConverter): + """Operator converter for NonMaxSuppression.""" + + @classmethod + def _impl_v10(cls, inputs, attr, params): + # Get parameter values + boxes = inputs[0] + scores = inputs[1] + max_output_boxes_per_class = inputs["max_output_boxes_per_class"] + iou_threshold = inputs["iou_threshold"] + score_threshold = inputs["score_threshold"] + print(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) + if "center_point_box" in attr: + assert ( + attr["center_point_box"] == 0 + ), "Only support center_point_box = 0 in onnx importer right now" + + def pad_last_dim(x): + return _op.expand_dims(x, -1, 1) + + # if iou_threshold is not None: + # # assume iou_threshold is constant + # iou_threshold = np.atleast_1d(iou_threshold.data.asnumpy())[0] + # else: + # iou_threshold = 0.0 + iou_threshold = 0.8 + if score_threshold is not None: + # assume iou_threshold is constant + score_threshold = np.atleast_1d(score_threshold.data.asnumpy())[0] + else: + score_threshold = 0.0 + # loop over classes + B, C, S = infer_shape(scores) + out = [] + for i in range(C): + class_scores = _op.strided_slice(scores, [0, i, 0], [B, i + 1, S], [1, 1, 1]) + class_scores = pad_last_dim(_op.squeeze(class_scores)) + data = _op.concatenate([class_scores, boxes], -1) + ct, data, indices = _op.vision.get_valid_counts( + data, score_threshold=score_threshold, id_index=-1, score_index=0 + ) + # reason why using get_valid_counts is for inference performance + # NNX NMS doesn't have parameter top_k + top_k = -1 + # TF doesn't have class id for nms input + score_index = 0 + nms_ret = _op.vision.non_max_suppression( + data=data, + valid_count=ct, + indices=indices, + max_output_size=max_output_boxes_per_class, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False, + ) + nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1) + onnx_output = _op.concatenate( + [ + pad_last_dim( + _op.broadcast_to( + pad_last_dim( + _op.arange( + _op.const(infer_shape(nms_padded_out)[0]), + dtype=infer_type(nms_padded_out).checked_type.dtype, + ), + ), + infer_shape(nms_ret[0]), + ), + ), + _op.broadcast_to( + _op.const(i, dtype=infer_type(nms_padded_out).checked_type.dtype), + infer_shape(nms_padded_out), + ), + nms_padded_out, + ], + -1, + ) + nms_size = _op.cast(nms_ret[1], "int64") + for batch in range(B): + start = [batch, 0, 0] + end = _op.concatenate( + [ + _op.const( + np.array( + [ + batch + 1, + ] + ), + dtype="int64", + ), + _op.reshape( + _op.strided_slice(nms_size, [batch, 0], [batch + 1, 1], [1, 1, 1]), [1] + ), + _op.const( + np.array( + [ + 3, + ] + ), + dtype="int64", + ), + ], + 0, + ) + out += [_op.squeeze(_op.strided_slice(onnx_output, start, end, [1, 1, 1]), [0])] + + out = [out[i * 
B + j] for j in range(B) for i in range(C)] + return _op.concatenate(out, axis=0) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2415,6 +2530,7 @@ def _get_convert_map(opset): # defs/vision "MaxRoiPool": MaxRoiPool.get_converter(opset), "RoiAlign": RoiAlign.get_converter(opset), + "NonMaxSuppression": NonMaxSuppression.get_converter(opset), # defs/reduction "ReduceMax": ReduceMax.get_converter(opset), "ReduceMin": ReduceMin.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d7a07f7271a9..ce652be3b725 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -147,6 +147,8 @@ def flatten(out): for target in targets: ctx = tvm.context(target, 0) + print(target, ctx) + if target == "cuda" or target == "nvptx":continue if use_vm: tvm_out = get_tvm_output_with_vm( model, @@ -159,7 +161,7 @@ def flatten(out): ) else: tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset) - + print(ort_out, tvm_out) tvm.testing.assert_allclose(flatten(ort_out), flatten(tvm_out), rtol=rtol, atol=atol) @@ -2821,7 +2823,6 @@ def forward(self, input): def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"): - print(x_shape, kernel_shape, strides, mode, pads, auto_pad) x_np = np.random.uniform(size=x_shape).astype("float32") if mode == "max": @@ -3689,6 +3690,68 @@ def verify_roi_align( verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) +#@tvm.testing.uses_gpu +def test_non_max_suppression(): + boxes = np.array( [0.0, 0.0, 0.3, 0.3, + 0.0, 0.0, 0.4, 0.4, + 0.0, 0.0, 0.5, 0.5, + 0.5, 0.5, 0.9, 0.9, + 0.5, 0.5, 1.0, 1.0, + + 0.0, 0.0, 0.3, 0.3, + 0.0, 0.0, 0.4, 0.4, + 0.5, 0.5, 0.95, 0.95, + 0.5, 0.5, 0.96, 0.96, + 0.5, 0.5, 1.0, 1.0]).reshape([2, 5, 4]).astype("float32") + scores = np.array( [0.1, 0.2, 0.6, 0.3, 0.9, + 0.1, 0.2, 0.6, 0.3, 0.9, + + 0.1, 0.2, 0.6, 0.3, 0.9, + 0.1, 0.2, 0.6, 0.3, 0.9]).reshape([2, 2, 5]).astype("float32") + max_output_boxes_per_class = np.array(2).astype("int64") + iou_threshold = np.array(0.8).astype("float32") + output_dims = [8, 3] + #test.AddInput("iou_threshold", {}, {0.8f}); + #test.AddOutput("selected_indices", {8, 3}, + # {0L, 0L, 4L, + # 0L, 0L, 2L, + # 0L, 1L, 4L, + # 0L, 1L, 2L, + + # 1L, 0L, 4L, + # 1L, 0L, 1L, + # 1L, 1L, 4L, + # 1L, 1L, 1L}); + node = helper.make_node( + "NonMaxSuppression", + inputs=["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"], + outputs=["Y"], + center_point_box=0, + ) + + graph = helper.make_graph( + [node], + "nms_test", + inputs=[ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes.shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores.shape), + helper.make_tensor_value_info( + "max_output_boxes_per_class", + TensorProto.INT64, + max_output_boxes_per_class.shape), + helper.make_tensor_value_info( + "iou_threshold", + TensorProto.FLOAT, + iou_threshold.shape + ), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, output_dims)], + ) + + model = helper.make_model(graph, producer_name="nms_test") + + + verify_with_ort_with_inputs(model, [boxes, scores, max_output_boxes_per_class, iou_threshold], use_vm=True) def verify_cond_loop(): y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) From d121d254006287217f443eae63c9354ed7116674 
Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 15 Oct 2020 12:58:12 -0600 Subject: [PATCH 02/10] support dynamic iou_threshold --- include/tvm/relay/attrs/vision.h | 3 +-- python/tvm/relay/frontend/onnx.py | 14 +++----------- python/tvm/relay/op/strategy/generic.py | 4 +++- python/tvm/relay/op/vision/nms.py | 4 +++- python/tvm/topi/vision/nms.py | 9 ++++++--- src/relay/op/vision/nms.cc | 14 +++++++------- 6 files changed, 23 insertions(+), 25 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 2b905f5bd04b..ff5af8032e8d 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -89,7 +89,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { /*! \brief Attributes used in non_maximum_suppression operator */ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { Optional max_output_size; - double iou_threshold; + Optional iou_threshold; bool force_suppress; int top_k; int coord_start; @@ -101,7 +101,6 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 5); + ICHECK_EQ(types.size(), 6); const auto* data = types[0].as(); const auto* valid_count = types[1].as(); const NonMaximumSuppressionAttrs* param = attrs.as(); @@ -90,18 +90,17 @@ bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, fields.push_back(TensorType(oshape, DataType::Int(32))); std::vector countshape({dshape[0], 1}); fields.push_back(TensorType(countshape, DataType::Int(32))); - reporter->Assign(types[4], TupleType(Array(fields))); + reporter->Assign(types[5], TupleType(Array(fields))); } else { - reporter->Assign(types[4], TensorType(dshape, data->dtype)); + reporter->Assign(types[5], TensorType(dshape, data->dtype)); } return true; } -Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, double iou_threshold, +Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, Expr iou_threshold, bool force_suppress, int top_k, int coord_start, int score_index, int id_index, bool return_indices, bool invalid_to_bottom) { auto attrs = make_object(); - attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; attrs->top_k = top_k; attrs->coord_start = coord_start; @@ -110,7 +109,7 @@ Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, do attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; static const Op& op = Op::Get("vision.non_max_suppression"); - return Call(op, {data, valid_count, indices, max_output_size}, Attrs(attrs), {}); + return Call(op, {data, valid_count, indices, max_output_size, iou_threshold}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS); @@ -121,11 +120,12 @@ be in the format of [class_id, score, left, top, right, bottom] or [score, left, top, right, bottom]. Set id_index to be -1 to ignore class_id axis. 
)doc" TVM_ADD_FILELINE) - .set_num_inputs(4) + .set_num_inputs(5) .add_argument("data", "Tensor", "Input data.") .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.") .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.") + .add_argument("iou_threshold", "Tensor", "Threshold for box overlap.") .set_support_level(5) .add_type_rel("NMS", NMSRel); From 2bb9949bbd9a79c64c4a26e764327e69d83decde Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 23 Oct 2020 10:18:34 -0600 Subject: [PATCH 03/10] WIP NMS with while loops --- python/tvm/relay/backend/_backend.py | 7 ++ python/tvm/relay/frontend/onnx.py | 124 ++++++++++++++++++++- src/relay/op/vision/nms.cc | 3 + src/runtime/vm/vm.cc | 15 +++ tests/python/frontend/onnx/test_forward.py | 1 - 5 files changed, 147 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 65b0c0ba87c7..bf4577a8924a 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -88,6 +88,13 @@ def _tensor_value_repr(tvalue): return str(tvalue.data.asnumpy()) +@tvm._ffi.register_func("relay._ndarray_repr") +def _tensor_constant_repr(tvalue): + tmp = tvalue.asnumpy() + return "NDArray of shape " + str(tmp.shape) + " and dtype " + str(tmp.dtype) +"\n\t" + str(tmp) + + + @tvm._ffi.register_func("relay._constant_repr") def _tensor_constant_repr(tvalue): dtype = tvm.runtime.DataType(tvalue.data.dtype) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4da9c7e13f5b..c726b4eeedcb 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2314,6 +2314,126 @@ def _impl_v10(cls, inputs, attr, params): max_output_boxes_per_class = inputs["max_output_boxes_per_class"] iou_threshold = inputs["iou_threshold"] score_threshold = inputs["score_threshold"] + + dtype = infer_type(boxes).checked_type.dtype + + if "center_point_box" in attr: + assert ( + attr["center_point_box"] == 0 + ), "Only support center_point_box = 0 in onnx importer right now" + + if iou_threshold is None: + iou_threshold = 0.0 + if score_threshold is None: + score_threshold = 0.0 + + zero = _op.const(np.array([0]), dtype="int64") + one = _op.const(np.array([1]), dtype="int64") + three = _op.const(np.array([1]), dtype="int64") + three_ones = _op.const(np.array([1, 1, 1]), dtype="int64") + two_ones = _op.const(np.array([1, 1]), dtype="int64") + + batch = _expr.var("batch", shape=(1,), dtype="int64") + B = _expr.var("B", shape=(1,), dtype="int64") + nms_size = _expr.var("nms_size", shape=(_ty.Any(), _ty.Any()), dtype="int64") + onnx_output = _expr.var("onnx_output", shape=(_ty.Any(), _ty.Any(), 3), dtype="int64") + out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") + + def _inner_cond(batch, B, nms_size, onnx_output, out): + return _op.min(_op.less(batch, B)) + + def _inner_body(batch, B, nms_size, onnx_output, out): + start = _op.concatenate([batch, zero], axis=0) + end = _op.concatenate([batch + one, one], axis=0) + num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, two_ones), [1]) + start = _op.concatenate([batch, zero, zero], axis=0) + end = _op.concatenate([batch + one, num_valid_boxes, three], axis=0) + new_out = _op.squeeze(_op.strided_slice(onnx_output, start, end, three_ones), [0]) + #new_out = _op.const(np.array([1,1,1]).reshape([1,3]), dtype="int64") + print(infer_type(new_out), 
infer_type(out)) + return batch + one, B, nms_size, onnx_output, _op.concatenate([out, new_out], axis=0) + + inner_loop = _loops.while_loop(_inner_cond, [batch, B, nms_size, onnx_output, out], _inner_body) + + def pad_last_dim(x): + return _op.expand_dims(x, -1, 1) + + #Initial Values + + i = _expr.var("i", shape=(1, ), dtype="int64") + scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype) + boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype) + B = _expr.var("B", shape=(1,), dtype="int64") + C = _expr.var("C", shape=(1,), dtype="int64") + S = _expr.var("S", shape=(1,), dtype="int64") + out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") + + def _outer_cond(i, scores, boxes, B, C, S, out): + return _op.min(_op.less(i, C)) + + def _outer_body(i, scores, boxes, B, C, S, out): + begin = _op.concatenate([zero, i, zero], axis=0) + end = _op.concatenate([B, i + one, S], axis=0) + class_scores = _op.strided_slice(scores, begin, end, three_ones) + class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1) + data = _op.concatenate([class_scores, boxes], axis=-1) + + ct, data, indices = _op.vision.get_valid_counts( + data, score_threshold=score_threshold, id_index=-1, score_index=0 + ) + # reason why using get_valid_counts is for inference performance + # ONNX NMS doesn't have parameter top_k + top_k = -1 + # ONNX doesn't have class id for nms input + score_index = 0 + nms_ret = _op.vision.non_max_suppression( + data=data, + valid_count=ct, + indices=indices, + max_output_size=max_output_boxes_per_class, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False, + ) + nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1) + batch_num = _op.expand_dims(_op.arange(_op.reshape(B, []), dtype="int64"), -1, 1) + batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64")) + batch_num = _op.expand_dims(batch_num, -1, 1) + class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64")) + onnx_output = _op.concatenate([batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1) + nms_size = _op.cast(nms_ret[1], "int64") + init_count = _op.const(np.array([0]), dtype="int64") + init_out = _op.reshape(_op.const([], dtype="int64"), [0, 3]) + + inner_loop_vals = inner_loop(init_count, B, nms_size, onnx_output, init_out) + new_out = _expr.TupleGetItem(inner_loop_vals, 4) + return [i + one, scores, boxes, B, C, S, _op.concatenate([out, new_out], axis=0)] + + outer_loop = _loops.while_loop(_outer_cond, [i, scores_var, boxes_var, B, C, S, out], _outer_body) + + # loop over classes + B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3) + init_count = _op.const(np.array([0]), dtype="int64") + init_out = _op.reshape(_op.const([], dtype="int64"), [0, 3]) + print(infer_type(scores)) + print(infer_type(boxes)) + loop_vals = outer_loop(init_count, scores, boxes, B, C, S, init_out) + print(infer_type(_expr.TupleGetItem(loop_vals, 6))) + return _expr.TupleGetItem(loop_vals, 6) + + @classmethod + def _impl_static(cls, inputs, attr, params): + # Get parameter values + boxes = inputs[0] + scores = inputs[1] + max_output_boxes_per_class = inputs["max_output_boxes_per_class"] + iou_threshold = inputs["iou_threshold"] + score_threshold = inputs["score_threshold"] if "center_point_box" in attr: assert ( attr["center_point_box"] == 0 @@ -2331,7 +2451,7 @@ def pad_last_dim(x): out = [] for i in 
range(C): class_scores = _op.strided_slice(scores, [0, i, 0], [B, i + 1, S], [1, 1, 1]) - class_scores = pad_last_dim(_op.squeeze(class_scores)) + class_scores = pad_last_dim(_op.squeeze(class_scores, )) data = _op.concatenate([class_scores, boxes], -1) ct, data, indices = _op.vision.get_valid_counts( data, score_threshold=score_threshold, id_index=-1, score_index=0 @@ -2391,7 +2511,7 @@ def pad_last_dim(x): dtype="int64", ), _op.reshape( - _op.strided_slice(nms_size, [batch, 0], [batch + 1, 1], [1, 1, 1]), [1] + _op.strided_slice(nms_size, [batch, 0], [batch + 1, 1], [1, 1]), [1] ), _op.const( np.array( diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 4a8c878eb634..623571060561 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -33,6 +33,7 @@ bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& att const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); + if (data == nullptr) return false; const auto& dshape = data->shape; ICHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; @@ -75,7 +76,9 @@ bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 6); const auto* data = types[0].as(); + if (data == nullptr) return false; const auto* valid_count = types[1].as(); + if (valid_count == nullptr) return false; const NonMaximumSuppressionAttrs* param = attrs.as(); const auto& dshape = data->shape; const auto& vshape = valid_count->shape; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 473b5d759272..c95cdc8286f9 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -261,6 +261,19 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In TVMRetValue rv; func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } +std::ostream& operator<<(std::ostream& out, const NDArray& array) { + static const PackedFunc* fprint = Registry::Get("relay._ndarray_repr"); + CHECK(fprint); + std::string data = (*fprint)(array); + out << data; + return out; +} + +void print(const ObjectRef& arg) { + if (arg->IsInstance()) { + std::cout << "\t" << Downcast(arg) << std::endl; + } +} void VirtualMachine::LoadExecutable(const Executable* exec) { ICHECK(exec) << "The executable is not created yet."; @@ -280,6 +293,7 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true); ICHECK(pf != nullptr) << "Cannot find function in module: " << packed_name; packed_funcs_[packed_index] = pf; + std::cout << packed_name << " -> " << packed_index < Date: Wed, 28 Oct 2020 13:06:14 -0600 Subject: [PATCH 04/10] working nms with dynamic shapes --- python/tvm/relay/frontend/onnx.py | 291 ++++++++++----------- tests/python/frontend/onnx/test_forward.py | 119 ++++++--- 2 files changed, 232 insertions(+), 178 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c726b4eeedcb..de5656283fc8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2329,49 +2329,56 @@ def _impl_v10(cls, inputs, attr, params): zero = _op.const(np.array([0]), dtype="int64") one = _op.const(np.array([1]), dtype="int64") - three = _op.const(np.array([1]), dtype="int64") - three_ones = _op.const(np.array([1, 1, 1]), dtype="int64") + three = _op.const(np.array([3]), dtype="int64") two_ones = _op.const(np.array([1, 1]), dtype="int64") - - batch = _expr.var("batch", shape=(1,), 
dtype="int64") - B = _expr.var("B", shape=(1,), dtype="int64") - nms_size = _expr.var("nms_size", shape=(_ty.Any(), _ty.Any()), dtype="int64") - onnx_output = _expr.var("onnx_output", shape=(_ty.Any(), _ty.Any(), 3), dtype="int64") - out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") - - def _inner_cond(batch, B, nms_size, onnx_output, out): - return _op.min(_op.less(batch, B)) - - def _inner_body(batch, B, nms_size, onnx_output, out): - start = _op.concatenate([batch, zero], axis=0) - end = _op.concatenate([batch + one, one], axis=0) - num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, two_ones), [1]) - start = _op.concatenate([batch, zero, zero], axis=0) - end = _op.concatenate([batch + one, num_valid_boxes, three], axis=0) - new_out = _op.squeeze(_op.strided_slice(onnx_output, start, end, three_ones), [0]) - #new_out = _op.const(np.array([1,1,1]).reshape([1,3]), dtype="int64") - print(infer_type(new_out), infer_type(out)) - return batch + one, B, nms_size, onnx_output, _op.concatenate([out, new_out], axis=0) - - inner_loop = _loops.while_loop(_inner_cond, [batch, B, nms_size, onnx_output, out], _inner_body) + three_ones = _op.const(np.array([1, 1, 1]), dtype="int64") + four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64") def pad_last_dim(x): return _op.expand_dims(x, -1, 1) - #Initial Values - - i = _expr.var("i", shape=(1, ), dtype="int64") + # First Loop Vars + i = _expr.var("i", shape=(1,), dtype="int64") scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype) boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype) + max_output_boxes_per_class_var = _expr.var( + "max_output_boxes_per_class_var", shape=(), dtype="int64" + ) + iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32") B = _expr.var("B", shape=(1,), dtype="int64") C = _expr.var("C", shape=(1,), dtype="int64") S = _expr.var("S", shape=(1,), dtype="int64") - out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") - - def _outer_cond(i, scores, boxes, B, C, S, out): + # Outputs of first loop should be padded nms values shape (B, C, 3) + onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") + # and sizes of valid outputs, shape (B, C, 1) + nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") + + def _first_cond( + i, + scores, + boxes, + B, + C, + S, + max_output_boxes_per_class, + iou_threshold, + onnx_out, + nms_size_out, + ): return _op.min(_op.less(i, C)) - def _outer_body(i, scores, boxes, B, C, S, out): + def _first_body( + i, + scores, + boxes, + B, + C, + S, + max_output_boxes_per_class, + iou_threshold, + onnx_out, + nms_size_out, + ): begin = _op.concatenate([zero, i, zero], axis=0) end = _op.concatenate([B, i + one, S], axis=0) class_scores = _op.strided_slice(scores, begin, end, three_ones) @@ -2401,133 +2408,121 @@ def _outer_body(i, scores, boxes, B, C, S, out): invalid_to_bottom=False, ) nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1) - batch_num = _op.expand_dims(_op.arange(_op.reshape(B, []), dtype="int64"), -1, 1) + batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1) batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64")) batch_num = _op.expand_dims(batch_num, -1, 1) class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64")) - onnx_output = _op.concatenate([batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1) + new_onnx_out = 
_op.concatenate( + [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1 + ) + new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1) nms_size = _op.cast(nms_ret[1], "int64") - init_count = _op.const(np.array([0]), dtype="int64") - init_out = _op.reshape(_op.const([], dtype="int64"), [0, 3]) - - inner_loop_vals = inner_loop(init_count, B, nms_size, onnx_output, init_out) - new_out = _expr.TupleGetItem(inner_loop_vals, 4) - return [i + one, scores, boxes, B, C, S, _op.concatenate([out, new_out], axis=0)] + nms_size = _op.expand_dims(nms_size, 1, 1) + return [ + i + one, + scores, + boxes, + B, + C, + S, + max_output_boxes_per_class, + iou_threshold, + _op.concatenate([onnx_out, new_onnx_out], axis=1), + _op.concatenate([nms_size_out, nms_size], axis=1), + ] + + first_loop = _loops.while_loop( + _first_cond, + [ + i, + scores_var, + boxes_var, + B, + C, + S, + max_output_boxes_per_class_var, + iou_threshold_var, + onnx_out, + nms_size_out, + ], + _first_body, + ) - outer_loop = _loops.while_loop(_outer_cond, [i, scores_var, boxes_var, B, C, S, out], _outer_body) + # Second inner Loop Vars + i = _expr.var("i", shape=(1,), dtype="int64") + j = _expr.var("j", shape=(1,), dtype="int64") + B = _expr.var("B", shape=(1,), dtype="int64") + C = _expr.var("C", shape=(1,), dtype="int64") + # Outputs of first loop should be padded nms values shape (B, C, 3) + onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") + # and sizes of valid outputs, shape (B, C, 1) + nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") + out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") - # loop over classes - B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3) - init_count = _op.const(np.array([0]), dtype="int64") - init_out = _op.reshape(_op.const([], dtype="int64"), [0, 3]) - print(infer_type(scores)) - print(infer_type(boxes)) - loop_vals = outer_loop(init_count, scores, boxes, B, C, S, init_out) - print(infer_type(_expr.TupleGetItem(loop_vals, 6))) - return _expr.TupleGetItem(loop_vals, 6) + def _inner_cond(i, j, C, onnx_out, nms_size, out): + return _op.min(_op.less(j, C)) - @classmethod - def _impl_static(cls, inputs, attr, params): - # Get parameter values - boxes = inputs[0] - scores = inputs[1] - max_output_boxes_per_class = inputs["max_output_boxes_per_class"] - iou_threshold = inputs["iou_threshold"] - score_threshold = inputs["score_threshold"] - if "center_point_box" in attr: - assert ( - attr["center_point_box"] == 0 - ), "Only support center_point_box = 0 in onnx importer right now" + def _inner_body(i, j, C, onnx_out, nms_size, out): + start = _op.concatenate([i, j, zero], axis=0) + end = _op.concatenate([i + one, j + one, one], axis=0) + num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1]) + start = _op.concatenate([i, j, zero, zero], axis=0) + end = _op.concatenate([i + one, j + one, num_valid_boxes, three], axis=0) + new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1]) + return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0) - def pad_last_dim(x): - return _op.expand_dims(x, -1, 1) + inner_loop = _loops.while_loop( + _inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body + ) - if iou_threshold is None: - iou_threshold = 0.0 - if score_threshold is None: - score_threshold = 0.0 - # loop over classes - B, C, S = infer_shape(scores) - out = [] - for i in range(C): - class_scores = _op.strided_slice(scores, [0, i, 
0], [B, i + 1, S], [1, 1, 1]) - class_scores = pad_last_dim(_op.squeeze(class_scores, )) - data = _op.concatenate([class_scores, boxes], -1) - ct, data, indices = _op.vision.get_valid_counts( - data, score_threshold=score_threshold, id_index=-1, score_index=0 - ) - # reason why using get_valid_counts is for inference performance - # NNX NMS doesn't have parameter top_k - top_k = -1 - # TF doesn't have class id for nms input - score_index = 0 - nms_ret = _op.vision.non_max_suppression( - data=data, - valid_count=ct, - indices=indices, - max_output_size=max_output_boxes_per_class, - iou_threshold=iou_threshold, - force_suppress=True, - top_k=top_k, - coord_start=1, - score_index=score_index, - id_index=-1, - return_indices=True, - invalid_to_bottom=False, - ) - nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1) - onnx_output = _op.concatenate( - [ - pad_last_dim( - _op.broadcast_to( - pad_last_dim( - _op.arange( - _op.const(infer_shape(nms_padded_out)[0]), - dtype=infer_type(nms_padded_out).checked_type.dtype, - ), - ), - infer_shape(nms_ret[0]), - ), - ), - _op.broadcast_to( - _op.const(i, dtype=infer_type(nms_padded_out).checked_type.dtype), - infer_shape(nms_padded_out), - ), - nms_padded_out, - ], - -1, - ) - nms_size = _op.cast(nms_ret[1], "int64") - for batch in range(B): - start = [batch, 0, 0] - end = _op.concatenate( - [ - _op.const( - np.array( - [ - batch + 1, - ] - ), - dtype="int64", - ), - _op.reshape( - _op.strided_slice(nms_size, [batch, 0], [batch + 1, 1], [1, 1]), [1] - ), - _op.const( - np.array( - [ - 3, - ] - ), - dtype="int64", - ), - ], - 0, - ) - out += [_op.squeeze(_op.strided_slice(onnx_output, start, end, [1, 1, 1]), [0])] + # Second Outer Loop Vars + i = _expr.var("i", shape=(1,), dtype="int64") + j = _expr.var("j", shape=(1,), dtype="int64") + B = _expr.var("B", shape=(1,), dtype="int64") + C = _expr.var("C", shape=(1,), dtype="int64") + # Outputs of first loop should be padded nms values shape (B, C, 3) + onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") + # and sizes of valid outputs, shape (B, C, 1) + nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") + out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") + + def _outer_cond(i, B, C, onnx_out, nms_size_out, out): + return _op.min(_op.less(i, B)) + + def _outer_body(i, B, C, onnx_out, nms_size_out, out): + init_count = _op.const(np.array([0]), dtype="int64") + inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out) + return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5) + + outer_loop = _loops.while_loop( + _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body + ) + + B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3) + init_count = _op.const(np.array([0]), dtype="int64") + init_onnx_out = _op.const([], dtype="int64") + init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, zero, S, three], 0)) + init_nms_size_out = _op.const([], dtype="int64") + init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, zero, one], 0)) + loop_vals = first_loop( + init_count, + scores, + boxes, + B, + C, + S, + max_output_boxes_per_class, + iou_threshold, + init_onnx_out, + init_nms_size_out, + ) + onnx_output = _expr.TupleGetItem(loop_vals, 8) + nms_size_output = _expr.TupleGetItem(loop_vals, 9) - out = [out[i * B + j] for j in range(B) for i in range(C)] - return _op.concatenate(out, axis=0) + init_count = 
_op.const(np.array([0]).astype("int64"), dtype="int64") + init_out = _op.const(np.array([]).reshape([0, 3]).astype("int64"), dtype="int64") + loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out) + return _expr.TupleGetItem(loop_vals, 5) # compatible operators that do NOT require any conversion. diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 0fcf34e120ef..fd2e96b9ab8b 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -53,10 +53,12 @@ def get_tvm_output_with_vm( mod, params = relay.frontend.from_onnx( graph_def, shape_dict, opset=opset, freeze_params=freeze_params ) + if convert_to_static: from tvm.relay import transform mod = transform.DynamicToStatic()(mod) + ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) result = ex.evaluate()(*input_data) if isinstance(result, tvm.runtime.NDArray): @@ -146,8 +148,8 @@ def flatten(out): for target in targets: ctx = tvm.context(target, 0) - print(target, ctx) - if target == "cuda" or target == "nvptx":continue + if target == "cuda" or target == "nvptx": + continue if use_vm: tvm_out = get_tvm_output_with_vm( model, @@ -160,7 +162,6 @@ def flatten(out): ) else: tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, dtype, opset=opset) - print(ort_out, tvm_out) tvm.testing.assert_allclose(flatten(ort_out), flatten(tvm_out), rtol=rtol, atol=atol) @@ -3689,29 +3690,90 @@ def verify_roi_align( verify_roi_align((5, 4, 16, 14), 32, 7, 7, sampling_ratio=1, spatial_scale=1.0) verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) -#@tvm.testing.uses_gpu + +# @tvm.testing.uses_gpu def test_non_max_suppression(): - boxes = np.array( [0.0, 0.0, 0.3, 0.3, - 0.0, 0.0, 0.4, 0.4, - 0.0, 0.0, 0.5, 0.5, - 0.5, 0.5, 0.9, 0.9, - 0.5, 0.5, 1.0, 1.0, - - 0.0, 0.0, 0.3, 0.3, - 0.0, 0.0, 0.4, 0.4, - 0.5, 0.5, 0.95, 0.95, - 0.5, 0.5, 0.96, 0.96, - 0.5, 0.5, 1.0, 1.0]).reshape([2, 5, 4]).astype("float32") - scores = np.array( [0.1, 0.2, 0.6, 0.3, 0.9, - 0.1, 0.2, 0.6, 0.3, 0.9, - - 0.1, 0.2, 0.6, 0.3, 0.9, - 0.1, 0.2, 0.6, 0.3, 0.9]).reshape([2, 2, 5]).astype("float32") + boxes = ( + np.array( + [ + 0.0, + 0.0, + 0.3, + 0.3, + 0.0, + 0.0, + 0.4, + 0.4, + 0.0, + 0.0, + 0.5, + 0.5, + 0.5, + 0.5, + 0.9, + 0.9, + 0.5, + 0.5, + 1.0, + 1.0, + 0.0, + 0.0, + 0.3, + 0.3, + 0.0, + 0.0, + 0.4, + 0.4, + 0.5, + 0.5, + 0.95, + 0.95, + 0.5, + 0.5, + 0.96, + 0.96, + 0.5, + 0.5, + 1.0, + 1.0, + ] + ) + .reshape([2, 5, 4]) + .astype("float32") + ) + scores = ( + np.array( + [ + 0.1, + 0.2, + 0.6, + 0.3, + 0.9, + 0.1, + 0.2, + 0.6, + 0.3, + 0.9, + 0.1, + 0.2, + 0.6, + 0.3, + 0.9, + 0.1, + 0.2, + 0.6, + 0.3, + 0.9, + ] + ) + .reshape([2, 2, 5]) + .astype("float32") + ) max_output_boxes_per_class = np.array(2).astype("int64") iou_threshold = np.array(0.8).astype("float32") output_dims = [8, 3] - #test.AddInput("iou_threshold", {}, {0.8f}); - #test.AddOutput("selected_indices", {8, 3}, + # test.AddInput("iou_threshold", {}, {0.8f}); + # test.AddOutput("selected_indices", {8, 3}, # {0L, 0L, 4L, # 0L, 0L, 2L, # 0L, 1L, 4L, @@ -3735,22 +3797,19 @@ def test_non_max_suppression(): helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes.shape), helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores.shape), helper.make_tensor_value_info( - "max_output_boxes_per_class", - TensorProto.INT64, - max_output_boxes_per_class.shape), - helper.make_tensor_value_info( - "iou_threshold", - 
TensorProto.FLOAT, - iou_threshold.shape + "max_output_boxes_per_class", TensorProto.INT64, max_output_boxes_per_class.shape ), + helper.make_tensor_value_info("iou_threshold", TensorProto.FLOAT, iou_threshold.shape), ], outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, output_dims)], ) model = helper.make_model(graph, producer_name="nms_test") + verify_with_ort_with_inputs( + model, [boxes, scores, max_output_boxes_per_class, iou_threshold], use_vm=True + ) - verify_with_ort_with_inputs(model, [boxes, scores, max_output_boxes_per_class, iou_threshold], use_vm=True) def verify_cond_loop(): y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) From f26c1fc5facf1eb2eaa5d63c196171f3314a16ce Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 28 Oct 2020 14:37:14 -0600 Subject: [PATCH 05/10] add a test with dynamic score_threshold and pass it --- include/tvm/relay/attrs/vision.h | 9 +- python/tvm/relay/frontend/onnx.py | 32 +++- python/tvm/relay/op/strategy/generic.py | 4 +- python/tvm/relay/op/vision/nms.py | 2 + python/tvm/topi/vision/nms.py | 7 +- src/relay/op/vision/nms.cc | 12 +- tests/python/frontend/onnx/test_forward.py | 193 +++++++++------------ 7 files changed, 127 insertions(+), 132 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index ff5af8032e8d..ca2c4a2b837d 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -73,14 +73,12 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode { - double score_threshold; + Optional score_threshold; int id_index; int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { - TVM_ATTR_FIELD(score_threshold) - .set_default(0.0) - .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(score_threshold).describe("Lower limit of score for valid bounding boxes."); TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id."); TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes."); } @@ -100,8 +98,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 2); + ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); if (data == nullptr) return false; const auto& dshape = data->shape; @@ -45,17 +45,16 @@ bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& att fields.push_back(TensorType(oshape_indices, DataType::Int(32))); // assign output type - reporter->Assign(types[1], TupleType(Array(fields))); + reporter->Assign(types[2], TupleType(Array(fields))); return true; } -Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) { +Expr MakeGetValidCounts(Expr data, Expr score_threshold, int id_index, int score_index) { auto attrs = make_object(); - attrs->score_threshold = score_threshold; attrs->id_index = id_index; attrs->score_index = score_index; static const Op& op = Op::Get("vision.get_valid_counts"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(op, {data, score_threshold}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts); @@ -65,8 +64,9 @@ RELAY_REGISTER_OP("vision.get_valid_counts") a score threshold. Also moves valid boxes to the top of input data. 
)doc" TVM_ADD_FILELINE) - .set_num_inputs(1) + .set_num_inputs(2) .add_argument("data", "Tensor", "Input data.") + .add_argument("score_threshold", "Tensor", "Minimum Score.") .set_support_level(5) .add_type_rel("GetValidCount", GetValidCountRel); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index fd2e96b9ab8b..e9c7577ec809 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3693,122 +3693,99 @@ def verify_roi_align( # @tvm.testing.uses_gpu def test_non_max_suppression(): - boxes = ( - np.array( - [ - 0.0, - 0.0, - 0.3, - 0.3, - 0.0, - 0.0, - 0.4, - 0.4, - 0.0, - 0.0, - 0.5, - 0.5, - 0.5, - 0.5, - 0.9, - 0.9, - 0.5, - 0.5, - 1.0, - 1.0, - 0.0, - 0.0, - 0.3, - 0.3, - 0.0, - 0.0, - 0.4, - 0.4, - 0.5, - 0.5, - 0.95, - 0.95, - 0.5, - 0.5, - 0.96, - 0.96, - 0.5, - 0.5, - 1.0, - 1.0, - ] - ) - .reshape([2, 5, 4]) - .astype("float32") - ) - scores = ( - np.array( - [ - 0.1, - 0.2, - 0.6, - 0.3, - 0.9, - 0.1, - 0.2, - 0.6, - 0.3, - 0.9, - 0.1, - 0.2, - 0.6, - 0.3, - 0.9, - 0.1, - 0.2, - 0.6, - 0.3, - 0.9, - ] - ) - .reshape([2, 2, 5]) - .astype("float32") - ) - max_output_boxes_per_class = np.array(2).astype("int64") - iou_threshold = np.array(0.8).astype("float32") - output_dims = [8, 3] - # test.AddInput("iou_threshold", {}, {0.8f}); - # test.AddOutput("selected_indices", {8, 3}, - # {0L, 0L, 4L, - # 0L, 0L, 2L, - # 0L, 1L, 4L, - # 0L, 1L, 2L, - - # 1L, 0L, 4L, - # 1L, 0L, 1L, - # 1L, 1L, 4L, - # 1L, 1L, 1L}); - node = helper.make_node( - "NonMaxSuppression", - inputs=["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"], - outputs=["Y"], - center_point_box=0, - ) - - graph = helper.make_graph( - [node], - "nms_test", - inputs=[ + def verify_nms( + boxes, scores, max_ouput_boxes_per_class, iou_threshold, score_threshold, output_dims + ): + input_names = ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"] + input_nodes = [ helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes.shape), helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores.shape), helper.make_tensor_value_info( "max_output_boxes_per_class", TensorProto.INT64, max_output_boxes_per_class.shape ), helper.make_tensor_value_info("iou_threshold", TensorProto.FLOAT, iou_threshold.shape), - ], - outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, output_dims)], - ) + ] + inputs = [boxes, scores, max_output_boxes_per_class, iou_threshold] + if score_threshold is not None: + input_names.append("score_threshold") + input_nodes.append( + helper.make_tensor_value_info( + "score_threshold", TensorProto.FLOAT, score_threshold.shape + ) + ) + inputs.append(score_threshold) + node = helper.make_node( + "NonMaxSuppression", + inputs=input_names, + outputs=["Y"], + center_point_box=0, + ) - model = helper.make_model(graph, producer_name="nms_test") + graph = helper.make_graph( + [node], + "nms_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, output_dims)], + ) - verify_with_ort_with_inputs( - model, [boxes, scores, max_output_boxes_per_class, iou_threshold], use_vm=True - ) + model = helper.make_model(graph, producer_name="nms_test") + + verify_with_ort_with_inputs(model, inputs, use_vm=True) + + print("start first test") + boxes = np.array( + [ + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.0, 0.0, 0.5, 0.5], + [0.5, 0.5, 0.9, 0.9], + [0.5, 0.5, 1.0, 1.0], + ], + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.5, 0.5, 
0.95, 0.95],
+                [0.5, 0.5, 0.96, 0.96],
+                [0.5, 0.5, 1.0, 1.0],
+            ],
+        ]
+    ).astype("float32")
+
+    scores = np.array(
+        [
+            [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]],
+            [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]],
+        ]
+    ).astype("float32")
+    max_output_boxes_per_class = np.array(2).astype("int64")
+    iou_threshold = np.array(0.8).astype("float32")
+    output_dims = [8, 3]
+    verify_nms(boxes, scores, max_output_boxes_per_class, iou_threshold, None, output_dims)
+    print("end first test")
+
+    print("start second test")
+    boxes = np.array(
+        [
+            [
+                [0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.1, 1.0, 1.1],
+                [0.0, -0.1, 1.0, 0.9],
+                [0.0, 10.0, 1.0, 11.0],
+                [0.0, 10.1, 1.0, 11.1],
+                [0.0, 100.0, 1.0, 101.0],
+            ]
+        ]
+    ).astype(np.float32)
+    scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32)
+    max_output_boxes_per_class = np.array([3]).astype(np.int64)
+    iou_threshold = np.array([0.5]).astype(np.float32)
+    score_threshold = np.array([0.4]).astype(np.float32)
+    output_dims = [2, 3]
+    verify_nms(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims
+    )
+    print("end second test")
 
 
 def verify_cond_loop():

From 3a334ce8265fa2bea6273bb1426140c971f6b929 Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Fri, 30 Oct 2020 16:13:02 -0700
Subject: [PATCH 06/10] Fix type checking in lambda lift

---
 src/relay/backend/vm/compiler.cc           |  1 +
 src/relay/backend/vm/lambda_lift.cc        | 43 ++++++++++++++++++----
 src/relay/op/tensor/transform.h            | 17 +++++++--
 tests/python/frontend/onnx/test_forward.py |  9 +++--
 4 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index f652644afa3c..bed2510cdf3c 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1070,6 +1070,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
 
   pass_seqs.push_back(transform::FuseOps());
   pass_seqs.push_back(transform::ToANormalForm());
+  pass_seqs.push_back(transform::InferType());
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
 
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index f21d0967701a..d32f07ff1d65 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -111,17 +111,27 @@ class LambdaLifter : public ExprMutator {
       }
       captured_vars.push_back(var);
     }
+
+    Array<Var> typed_captured_vars;
+    Map<Var, Expr> rebinding_map;
+    for (auto free_var : captured_vars) {
+      auto var = Var(free_var->name_hint(), free_var->checked_type());
+      typed_captured_vars.push_back(var);
+      rebinding_map.Set(free_var, var);
+    }
+
     if (recursive) {
       if (!captured_vars.empty()) {
-        Array<Expr> fvs;
-        for (auto fv : captured_vars) {
-          fvs.push_back(fv);
+        Array<Expr> fvs;
+        for (auto fv : captured_vars) {
+          fvs.push_back(fv);
+        }
+        lambda_map_.emplace(letrec_.back(), Call(global, fvs));
+      } else {
+        lambda_map_.emplace(letrec_.back(), global);
        }
-        lambda_map_.emplace(letrec_.back(), Call(global, fvs));
-      } else {
-        lambda_map_.emplace(letrec_.back(), global);
-      }
     }
+
     auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
 
     // When performing this optimization there are two cases.
@@ -150,7 +160,23 @@ class LambdaLifter : public ExprMutator {
     if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
       lifted_func = Function(body->params, body->body, body->ret_type, body->type_params);
     } else {
-      lifted_func = Function(captured_vars, body, func->func_type_annotation(), free_type_vars);
+      // When a closure is locally bound in a program, we have its full type information
+      // available to us.
+      //
+      // If we lift the closure out of its bound context it may have free variables which
+      // do not have type annotations.
+      //
+      // In this case we first type check the program, assigning a type to all sub-expressions.
+      //
+      // We then change the un-annotated free variables into annotated free variables, use
+      // bind to go from unannotated free variables -> annotated free variables, and then
+      // construct the "closure" function with fully annotated arguments, no longer relying
+      // on type inference.
+      auto before = Downcast<Function>(body)->params.size();
+      auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type,
+                                   func->type_params, func->attrs, func->span);
+      auto after = Downcast<Function>(rebound_body)->params.size();
+      CHECK_EQ(before, after);
+      lifted_func = Function(typed_captured_vars, rebound_body, func->func_type_annotation(), free_type_vars);
       lifted_func = MarkClosure(lifted_func);
     }
 
@@ -164,6 +190,7 @@ class LambdaLifter : public ExprMutator {
       global = module_->GetGlobalVar(name);
     } else {
       // Add the lifted function to the module.
+      std::cout << AsText(lifted_func) << std::endl;
       module_->Add(global, lifted_func);
     }
 
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 4173d57a84de..bc6e40740da8 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -44,21 +44,30 @@ template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
   // types: [data, result]
-  ICHECK_EQ(types.size(), 2);
+  ICHECK_EQ(types.size(), 2)
+      << "the arity of concatenate is 2, not " << types.size();
   /* If we receive a tuple we can continue, if we receive
    * anything but an incomplete type we should signal an
    * error.
*/ const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { - throw Error( - ErrorBuilder() << "concatenate requires a tuple of tensors as the first argument, found " - << PrettyPrint(types[0])); + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "concatenate requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0])); + return false; } else if (types[0].as() != nullptr) { return false; } const auto* param = attrs.as(); + if (param == nullptr) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) << "the call attributes are not defined"); + return false; + } + if (tensor_tuple->fields[0].as()) { return false; } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e9c7577ec809..8328eef9c1d7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -54,6 +54,11 @@ def get_tvm_output_with_vm( graph_def, shape_dict, opset=opset, freeze_params=freeze_params ) + from tvm.relay import transform + # print(mod.astext(show_meta_data=True)) + # self.mod = transform.AnnotateSpans()(mod) + # print(mod.astext(show_meta_data=False)) + if convert_to_static: from tvm.relay import transform @@ -3732,7 +3737,6 @@ def verify_nms( verify_with_ort_with_inputs(model, inputs, use_vm=True) - print("start first test") boxes = np.array( [ [ @@ -3762,9 +3766,7 @@ def verify_nms( iou_threshold = np.array(0.8).astype("float32") output_dims = [8, 3] verify_nms(boxes, scores, max_output_boxes_per_class, iou_threshold, None, output_dims) - print("end first test") - print("start second test") boxes = np.array( [ [ @@ -3785,7 +3787,6 @@ def verify_nms( verify_nms( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims ) - print("end second test") def verify_cond_loop(): From 96bbb1bf16a189720e087727e3e0228bc2278e04 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 3 Nov 2020 15:15:41 -0700 Subject: [PATCH 07/10] ONNX NMS working on GPU, had to remove threading from some kernels fix lint fix lambda lift tests fix unit tests respond to review comments fix lint --- python/tvm/relay/backend/_backend.py | 7 - python/tvm/relay/frontend/onnx.py | 58 ++- python/tvm/relay/op/vision/nms.py | 6 +- python/tvm/topi/cuda/nms.py | 427 +++++++++++--------- python/tvm/topi/cuda/sort.py | 8 +- src/relay/backend/vm/lambda_lift.cc | 26 +- src/relay/op/tensor/transform.h | 9 +- src/runtime/vm/vm.cc | 15 - tests/python/frontend/onnx/test_forward.py | 12 +- tests/python/relay/test_op_level5.py | 4 +- tests/python/relay/test_pass_lambda_lift.py | 3 + 11 files changed, 304 insertions(+), 271 deletions(-) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index bf4577a8924a..65b0c0ba87c7 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -88,13 +88,6 @@ def _tensor_value_repr(tvalue): return str(tvalue.data.asnumpy()) -@tvm._ffi.register_func("relay._ndarray_repr") -def _tensor_constant_repr(tvalue): - tmp = tvalue.asnumpy() - return "NDArray of shape " + str(tmp.shape) + " and dtype " + str(tmp.dtype) +"\n\t" + str(tmp) - - - @tvm._ffi.register_func("relay._constant_repr") def _tensor_constant_repr(tvalue): dtype = tvm.runtime.DataType(tvalue.data.dtype) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 449e60f1d740..23102aaa9d32 100644 --- a/python/tvm/relay/frontend/onnx.py +++ 
b/python/tvm/relay/frontend/onnx.py
@@ -2308,6 +2308,17 @@ class NonMaxSuppression(OnnxOpConverter):
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
+        """
+        High level note: ONNX implements what TF calls combined_non_max_suppression.
+        It passes in scores for each box for every class in the output and expects boxes to be
+        analyzed for each class independently.
+
+        It also asks for the data to be returned in a particular format.
+
+        To support these, we implement a series of loops:
+        the first loop splits over class number, performs NMS, and collects the outputs;
+        the second (nested) loop takes the outputs and transforms them into the format ONNX wants.
+        """
         # Get parameter values
         boxes = inputs[0]
         scores = inputs[1]
@@ -2337,17 +2348,17 @@ def conditionally_squeeze_scalar(x):
         max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
         iou_threshold = conditionally_squeeze_scalar(iou_threshold)
         score_threshold = conditionally_squeeze_scalar(score_threshold)
+
+        ## prepare utility constants
         zero = _op.const(np.array([0]), dtype="int64")
         one = _op.const(np.array([1]), dtype="int64")
+        two = _op.const(np.array([2]), dtype="int64")
         three = _op.const(np.array([3]), dtype="int64")
-        two_ones = _op.const(np.array([1, 1]), dtype="int64")
         three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
         four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")
 
-        def pad_last_dim(x):
-            return _op.expand_dims(x, -1, 1)
-
-        # First Loop Vars
+        ## First loop: split by class and perform NMS
+        # Create Loop Vars
         i = _expr.var("i", shape=(1,), dtype="int64")
         scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
         boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
@@ -2359,7 +2370,7 @@ def pad_last_dim(x):
         B = _expr.var("B", shape=(1,), dtype="int64")
         C = _expr.var("C", shape=(1,), dtype="int64")
         S = _expr.var("S", shape=(1,), dtype="int64")
-        # Outputs of first loop should be padded nms values shape (B, C, 3)
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
         onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
         # and sizes of valid outputs, shape (B, C, 1)
         nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
@@ -2377,6 +2388,7 @@ def _first_cond(
             onnx_out,
             nms_size_out,
         ):
+            # Loop over classes, end when i == C
             return _op.min(_op.less(i, C))
 
         def _first_body(
@@ -2392,12 +2404,15 @@ def _first_body(
             onnx_out,
             nms_size_out,
         ):
+            # slice to get current class
             begin = _op.concatenate([zero, i, zero], axis=0)
             end = _op.concatenate([B, i + one, S], axis=0)
             class_scores = _op.strided_slice(scores, begin, end, three_ones)
             class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
+            # combine scores and boxes
             data = _op.concatenate([class_scores, boxes], axis=-1)
 
+            # get valid counts
             ct, data, indices = _op.vision.get_valid_counts(
                 data, score_threshold=score_threshold, id_index=-1, score_index=0
             )
@@ -2406,6 +2421,7 @@ def _first_body(
             top_k = -1
             # ONNX doesn't have class id for nms input
             score_index = 0
+            # perform nms on current class
             nms_ret = _op.vision.non_max_suppression(
                 data=data,
                 valid_count=ct,
@@ -2420,6 +2436,7 @@ def _first_body(
                 return_indices=True,
                 invalid_to_bottom=False,
             )
+            # partially prepare ONNX output format by labeling batch_num, class_id
             nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
             batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
             batch_num = 
_op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64")) @@ -2429,6 +2446,7 @@ def _first_body( [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1 ) new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1) + # store valid nms outputs for this class nms_size = _op.cast(nms_ret[1], "int64") nms_size = _op.expand_dims(nms_size, 1, 1) return [ @@ -2445,6 +2463,7 @@ def _first_body( _op.concatenate([nms_size_out, nms_size], axis=1), ] + # create the first loop first_loop = _loops.while_loop( _first_cond, [ @@ -2463,6 +2482,8 @@ def _first_body( _first_body, ) + ## Second loop slices outputs of the first loop for valid boxes and + ## concats in the order ONNX wants # Second inner Loop Vars i = _expr.var("i", shape=(1,), dtype="int64") j = _expr.var("j", shape=(1,), dtype="int64") @@ -2475,14 +2496,17 @@ def _first_body( out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") def _inner_cond(i, j, C, onnx_out, nms_size, out): + # inner loop over number of classes return _op.min(_op.less(j, C)) def _inner_body(i, j, C, onnx_out, nms_size, out): - start = _op.concatenate([i, j, zero], axis=0) - end = _op.concatenate([i + one, j + one, one], axis=0) + # slice to get current batch and class for valid box indicator + start = _op.concatenate([i, j + one, zero], axis=0) + end = _op.concatenate([i + one, j + two, one], axis=0) num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1]) - start = _op.concatenate([i, j, zero, zero], axis=0) - end = _op.concatenate([i + one, j + one, num_valid_boxes, three], axis=0) + # slice to get current batch, class, and valid outputs + start = _op.concatenate([i, j + one, zero, zero], axis=0) + end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0) new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1]) return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0) @@ -2502,23 +2526,27 @@ def _inner_body(i, j, C, onnx_out, nms_size, out): out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") def _outer_cond(i, B, C, onnx_out, nms_size_out, out): + # Outer loop is over batch size return _op.min(_op.less(i, B)) def _outer_body(i, B, C, onnx_out, nms_size_out, out): + # Outer loop just calls inner loop init_count = _op.const(np.array([0]), dtype="int64") inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out) return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5) + # Create the second loop outer_loop = _loops.while_loop( _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body ) + # Call the first loop, perform NMS B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3) init_count = _op.const(np.array([0]), dtype="int64") - init_onnx_out = _op.const([], dtype="int64") - init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, zero, S, three], 0)) - init_nms_size_out = _op.const([], dtype="int64") - init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, zero, one], 0)) + init_onnx_out = _op.const([1], dtype="int64") + init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0)) + init_nms_size_out = _op.const([1], dtype="int64") + init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0)) loop_vals = first_loop( init_count, scores, @@ -2535,9 +2563,11 @@ def _outer_body(i, B, C, onnx_out, nms_size_out, out): onnx_output = _expr.TupleGetItem(loop_vals, 9) nms_size_output = _expr.TupleGetItem(loop_vals, 10) 
+        # Call the second loop, rework outputs into correct form
         init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
         init_out = _op.const(np.array([]).reshape([0, 3]).astype("int64"), dtype="int64")
         loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out)
+
         return _expr.TupleGetItem(loop_vals, 5)
 
 
diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py
index 1faec25a7f54..0a3df40b99df 100644
--- a/python/tvm/relay/op/vision/nms.py
+++ b/python/tvm/relay/op/vision/nms.py
@@ -48,7 +48,7 @@ def get_valid_counts(data, score_threshold, id_index=0, score_index=1):
     out_indices: relay.Expr
         Indices in input data
     """
-    if isinstance(score_threshold, float):
+    if not isinstance(score_threshold, expr.Expr):
         score_threshold = expr.const(score_threshold, "float32")
     return expr.TupleWrapper(
         _make.get_valid_counts(data, score_threshold, id_index, score_index), 3
@@ -128,9 +128,9 @@ def non_max_suppression(
         If return_indices is True, return relay.Tuple of two 2-D tensors, with
         shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
     """
-    if isinstance(max_output_size, int):
+    if not isinstance(max_output_size, expr.Expr):
         max_output_size = expr.const(max_output_size, "int32")
-    if isinstance(iou_threshold, float):
+    if not isinstance(iou_threshold, expr.Expr):
         iou_threshold = expr.const(iou_threshold, "float32")
     out = _make.non_max_suppression(
         data,
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index d51eb5ce1d11..12ee057d040a 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -52,64 +52,66 @@ def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
-def rearrange_indices_out_ir(data, out, valid_box_count):
-    """Hybrid routine to rearrange nms output to
+def rearrange_indices_out_ir(data, output, valid_box_count):
+    """Low level IR to rearrange nms output to
     move all valid entries to top.
 
     Parameters
     ----------
     data : tvm.te.Tensor or numpy NDArray
+        NMS output. 3-D tensor with shape
+        [batch_size, num_anchors, 6] or
+        [batch_size, num_anchors, 5], or 2-D
+        tensor with shape [batch_size, num_anchors].
 
     Returns
     -------
-    stmt : Stmt
-        The result IR statement.
+    output : tvm.te.Tensor or numpy NDArray
+        2-D tensor with shape [batch_size, num_anchors].
+
+    valid_box_count : tvm.te.Tensor or numpy NDArray
+        Tensor with shape [batch_size, 1], indicating
+        the number of valid boxes per batch.
""" batch_size = data.shape[0] num_anchors = data.shape[1] ib = tvm.tir.ir_builder.create() + data = ib.buffer_ptr(data) - out = ib.buffer_ptr(out) valid_box_count = ib.buffer_ptr(valid_box_count) - - one_count = tvm.tir.const(1, dtype="int32") - atomic_add_return = ib.allocate( - valid_box_count.dtype, (batch_size,), name="atomic_add_return", scope="local" - ) - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = (batch_size * num_anchors) // nthread_tx + 2 - - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - with ib.for_range(0, len_inner_for, name="i") as i: - idx = tx * len_inner_for + i - batch_idx = idxd(idx, num_anchors) - with ib.if_scope(idx < batch_size): - valid_box_count[idx] = 0 - with ib.if_scope(idx < batch_size * num_anchors): - with ib.if_scope(data[idx] >= 0): - atomic_add_return[batch_idx] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]), - one_count, - ) - out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = data[idx] - with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] < -num_anchors)): - atomic_add_return[batch_idx] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]), - one_count, - ) - out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0 - - with ib.if_scope(idxm(idx, num_anchors) >= valid_box_count[batch_idx]): - out[idx] = -1 + output = ib.buffer_ptr(output) + + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", batch_size) + valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local") + valid_idx[0] = 0 + with ib.for_range(0, num_anchors, name="j") as j: + with ib.if_scope(data[i, j] >= 0): + with ib.if_scope(data[i, j] > num_anchors): + output[i, valid_idx[0]] = 0 + valid_idx[0] = valid_idx[0] + 1 + with ib.else_scope(): + output[i, valid_idx[0]] = data[i, j] + valid_idx[0] = valid_idx[0] + 1 + with ib.else_scope(): + with ib.if_scope(data[i, j] < -num_anchors): + output[i, valid_idx[0]] = 0 + valid_idx[0] = valid_idx[0] + 1 + with ib.if_scope(j >= valid_idx[0]): + output[i, j] = -1 + valid_box_count[i, 0] = valid_idx[0] return ib.get() @@ -132,7 +134,7 @@ def get_valid_counts_ir( flag : Buffer 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - score_threshold : float32 + score_threshold : Buffer or float32 Lower limit of score for valid bounding boxes. 
id_index : optional, int @@ -157,47 +159,44 @@ def get_valid_counts_ir( valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) out_indices = ib.buffer_ptr(out_indices) - atomic_add_return = ib.allocate( - valid_count.dtype, (1,), name="atomic_add_return", scope="local" - ) - one_count = tvm.tir.const(1, dtype=valid_count.dtype) one = tvm.tir.const(1, dtype=out.dtype) - score_threshold = tvm.tir.FloatImm("float32", score_threshold) + if isinstance(score_threshold, float): + score_threshold = tvm.tir.FloatImm("float32", score_threshold) id_index = tvm.tir.IntImm("int32", id_index) score_index = tvm.tir.IntImm("int32", score_index) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = batch_size * num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - idxd = tvm.tir.indexdiv - - # initialize valid_count - with ib.if_scope(tid < batch_size): - valid_count[tid] = 0 - with ib.if_scope(tid < batch_size * num_anchors): - i = idxd(tid, num_anchors) - with ib.if_scope( - tvm.tir.all( - data[tid * elem_length + score_index] > score_threshold, - tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0), - ) - ): - atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_count[i]), one_count - ) - with ib.for_range(0, elem_length) as k: - out[tid * elem_length + k] = data[tid * elem_length + k] - out_indices[tid + k] = tid + k - with ib.else_scope(): - with ib.for_range(0, elem_length) as k: - out[tid * elem_length + k] = -one - out_indices[tid + k] = -one_count - + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = batch_size // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + valid_count[tid] = 0 + i = tid + with ib.for_range(0, num_anchors) as j: + score = data[(i * num_anchors + j) * elem_length + score_index] + with ib.if_scope( + tvm.tir.all( + score > score_threshold, + tvm.tir.any( + id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0 + ), + ) + ): + with ib.for_range(0, elem_length) as k: + out[(i * num_anchors + valid_count[i]) * elem_length + k] = data[ + (i * num_anchors + j) * elem_length + k + ] + out_indices[i * num_anchors + valid_count[i]] = j + valid_count[i] += 1 + with ib.if_scope(j >= valid_count[i]): + with ib.for_range(0, elem_length) as k: + out[(i * num_anchors + j) * elem_length + k] = -one + out_indices[i * num_anchors + j] = -1 return ib.get() @@ -210,7 +209,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): data : tvm.te.Tensor Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length]. - score_threshold : optional, float + score_threshold : optional, tvm.te.Tensor or float Lower limit of score for valid bounding boxes. id_index : optional, int @@ -277,12 +276,19 @@ def nms_ir( data : Buffer Buffer of output boxes with class and score. - sort_index : Buffer + sorted_index : Buffer Buffer of output box indexes sorted by score. valid_count : Buffer Buffer of number of valid output boxes. 
+ indices : Buffer + indices in original tensor, with shape [batch_size, num_anchors], + represents the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the second + dimension are like the output of arange(num_anchors) if get_valid_counts + is not used before non_max_suppression. + out : Buffer Output buffer. @@ -308,33 +314,50 @@ def nms_ir( score_index : optional, int Index of the scores/confidence of boxes. + return_indices : boolean + Whether to return box indices in input data. + Returns ------- stmt : Stmt The result IR statement. """ - def calculate_overlap(out_tensor, box_a_idx, box_b_idx): - """Calculate overlap of two boxes.""" - w = tvm.te.max( - 0.0, - tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) - - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx]), + def get_boundaries(output, box_idx): + l = tvm.te.min( + output[box_idx], + output[box_idx + 2], + ) + t = tvm.te.min( + output[box_idx + 1], + output[box_idx + 3], ) - h = tvm.te.max( - 0.0, - tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) - - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]), + r = tvm.te.max( + output[box_idx], + output[box_idx + 2], ) - i = w * h - u = ( - (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) - * (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) - + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) - * (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - - i + b = tvm.te.max( + output[box_idx + 1], + output[box_idx + 3], ) - return tvm.tir.Select(u <= 0.0, 0.0, i / u) + return l, t, r, b + + def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes.""" + a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) + b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) + + # Overlapping width and height + w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) + h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) + + # Overlapping area + area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area + u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + return tvm.tir.Select(u <= 0.0, 0.0, area / u) batch_size = data.shape[0] num_anchors = data.shape[1] @@ -345,108 +368,117 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): data = ib.buffer_ptr(data) sorted_index = ib.buffer_ptr(sorted_index) valid_count = ib.buffer_ptr(valid_count) + indices = ib.buffer_ptr(indices) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) - indices = ib.buffer_ptr(indices) num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - j = bx * max_threads + tx - - iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) + if isinstance(iou_threshold, float): + iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) top_k = tvm.tir.IntImm("int32", top_k) coord_start = tvm.tir.IntImm("int32", coord_start) id_index = tvm.tir.IntImm("int32", id_index) score_index = tvm.tir.IntImm("int32", score_index) force_suppress = tvm.tir.IntImm("int32", 1 if force_suppress else 0) - with ib.for_range(0, batch_size, 
for_type="unroll") as i: - base_idx = i * num_anchors * box_data_length - with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): - # Reorder output - nkeep = if_then_else( - tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] - ) - with ib.if_scope(j < nkeep): - with ib.for_range(0, box_data_length) as k: - out[(base_idx + j * box_data_length + k)] = data[ - (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) - ] - box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] - with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])): - with ib.if_scope(j < valid_count[i] - nkeep): + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", 1) + + with ib.for_range(0, batch_size) as i: + base_idx = i * num_anchors * box_data_length + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + # Reorder output + nkeep = if_then_else( + tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] + ) + with ib.for_range(0, nkeep) as j: with ib.for_range(0, box_data_length) as k: - out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 - box_indices[i * num_anchors + (j + nkeep)] = -1 - # Apply nms - with ib.for_range(0, valid_count[i]) as k: - offset_k = k * box_data_length - with ib.if_scope( - tvm.tir.all( - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), - ) - ): - with ib.if_scope(j < valid_count[i]): - offset_j = j * box_data_length + out[(base_idx + j * box_data_length + k)] = data[ + (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) + ] + box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] + with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])): + with ib.for_range(0, valid_count[i] - nkeep) as j: + with ib.for_range(0, box_data_length) as k: + out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + (j + nkeep)] = -1 + # Apply nms + with ib.for_range(0, valid_count[i]) as j: + with ib.for_range(0, j) as k: + offset_k = k * box_data_length with ib.if_scope( tvm.tir.all( - j > k, - out[base_idx + offset_j + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0), - tvm.tir.any( - force_suppress > 0, - id_index < 0, - out[base_idx + offset_k + id_index] - == out[base_idx + offset_j + id_index], - ), + out[base_idx + offset_k + score_index] > 0, + tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), ) ): - iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, - ) - with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_j + score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_j + id_index] = -1.0 - box_indices[i * num_anchors + j] = -1 - with ib.else_scope(): - with ib.if_scope(j < valid_count[i]): - offset_j = j * box_data_length + offset_j = j * box_data_length + with ib.if_scope( + tvm.tir.all( + j > k, + out[base_idx + offset_k + score_index] > 0, + tvm.tir.any( + id_index < 0, out[base_idx + offset_j + id_index] >= 0 + ), + tvm.tir.any( + force_suppress > 0, + id_index < 0, + out[base_idx + offset_k + id_index] + == out[base_idx + offset_j + id_index], + ), + ) + ): + iou = calculate_overlap( + out, + base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start, + ) + with ib.if_scope(iou >= 
iou_threshold): + out[base_idx + offset_j + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_j + id_index] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.for_range(0, box_data_length) as k: + out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] + box_indices[i * num_anchors + j] = j + # Set invalid entry to be -1 + with ib.for_range(0, num_anchors - valid_count[i]) as j: with ib.for_range(0, box_data_length) as k: - out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] - box_indices[i * num_anchors + j] = j - # Set invalid entry to be -1 - with ib.if_scope(j < num_anchors - valid_count[i]): - with ib.for_range(0, box_data_length) as k: - out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 - box_indices[i * num_anchors + j + valid_count[i]] = -1 - # Only return max_output_size number of valid boxes - num_valid_boxes[0] = 0 - with ib.if_scope(max_output_size > 0): - with ib.if_scope(j < valid_count[i]): - offset_j = j * box_data_length - with ib.if_scope(out[base_idx + offset_j] >= 0): - with ib.if_scope(num_valid_boxes[0] == max_output_size): - with ib.for_range(0, box_data_length) as k: - out[base_idx + offset_j + k] = -1.0 - box_indices[i * num_anchors + j] = -1 - with ib.else_scope(): - num_valid_boxes[0] += 1 + out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 + box_indices[i * num_anchors + j + valid_count[i]] = -1 + # Only return max_output_size number of valid boxes + num_valid_boxes[0] = 0 + with ib.if_scope(max_output_size > 0): + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(out[base_idx + offset_j] >= 0): + with ib.if_scope(num_valid_boxes[0] == max_output_size): + with ib.for_range(0, box_data_length) as k: + out[base_idx + offset_j + k] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): + num_valid_boxes[0] += 1 - if return_indices: - with ib.if_scope(j < valid_count[i]): - box_idx = box_indices[i * num_anchors + j] - with ib.if_scope(box_idx >= 0): - box_indices[i * num_anchors + j] = indices[i * num_anchors + box_idx] + if return_indices: + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = batch_size // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + i = bx * max_threads + tx + with ib.if_scope(i < batch_size): + with ib.for_range(0, valid_count[i]) as j: + idx = box_indices[i * num_anchors + j] + with ib.if_scope(idx >= 0): + box_indices[i * num_anchors + j] = indices[i * num_anchors + idx] return ib.get() @@ -486,11 +518,11 @@ def non_max_suppression( second dimension are like the output of arange(num_anchors) if get_valid_counts is not used before non_max_suppression. - max_output_size : optional, int + max_output_size : optional, tvm.te.Tensor or int Max number of output valid boxes for each instance. By default all valid boxes are returned. - iou_threshold : optional, float + iou_threshold : optional, tvm.te.Tensor or float Non-maximum suppression threshold. 
force_suppress : optional, boolean @@ -552,12 +584,7 @@ def non_max_suppression( score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) - target = tvm.target.Target.current() - if ( - target - and target.kind.name == "cuda" - and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True) - ): + if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True): sort_tensor = argsort_thrust( score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype ) @@ -570,6 +597,8 @@ def non_max_suppression( sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 ) + indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) @@ -597,19 +626,19 @@ def non_max_suppression( name="nms", tag="nms", ) - if return_indices: - out_buf = tvm.tir.decl_buffer( - box_indices.shape, box_indices.dtype, "out_buf", data_alignment=8 - ) + out_shape = box_indices.shape + valid_box_count_shape = [box_indices.shape[0], 1] + valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count") + output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output") return te.extern( - [box_indices.shape, (batch_size, 1)], + [out_shape, valid_box_count_shape], [box_indices], lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]), - dtype=[box_indices.dtype, valid_count.dtype], - in_buffers=[out_buf], - name="rearrange_indices_out", - tag="rearrange_indices_out", + dtype="int32", + out_buffers=[output, valid_box_count], + name="rearrange_indices_out_gpu", + tag="rearrange_indices_out_gpu", ) return out diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 0094ef1adf11..329f0fb897e5 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -104,9 +104,9 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): nthread_bx = shape[axis] // max_threads + 1 tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("vthread") + bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "virtual_thread", nthread_bx) + ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * nthread_tx + tx temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local") if indices_out is not None: @@ -202,9 +202,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): nthread_tx = max_threads nthread_bx = size // max_threads + 1 tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("vthread") + bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "virtual_thread", nthread_bx) + ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * nthread_tx + tx temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index d32f07ff1d65..8e9cc625063b 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -115,21 +115,21 @@ class LambdaLifter : public ExprMutator { Array typed_captured_vars; Map rebinding_map; for (auto free_var : captured_vars) { - auto var = 
Var(free_var->name_hint(), free_var->checked_type()); - typed_captured_vars.push_back(var); - rebinding_map.Set(free_var, var); + auto var = Var(free_var->name_hint(), free_var->checked_type()); + typed_captured_vars.push_back(var); + rebinding_map.Set(free_var, var); } if (recursive) { if (!captured_vars.empty()) { - Array fvs; - for (auto fv : captured_vars) { - fvs.push_back(fv); - } - lambda_map_.emplace(letrec_.back(), Call(global, fvs)); - } else { - lambda_map_.emplace(letrec_.back(), global); + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); } + lambda_map_.emplace(letrec_.back(), Call(global, fvs)); + } else { + lambda_map_.emplace(letrec_.back(), global); + } } auto body = Downcast(ExprMutator::VisitExpr_(func_node)); @@ -173,10 +173,12 @@ class LambdaLifter : public ExprMutator { // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. auto before = Downcast(body)->params.size(); - auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, func->type_params, func->attrs, func->span); + auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, + func->type_params, func->attrs, func->span); auto after = Downcast(rebound_body)->params.size(); CHECK_EQ(before, after); - lifted_func = Function(typed_captured_vars, rebound_body, func->func_type_annotation(), free_type_vars); + lifted_func = + Function(typed_captured_vars, rebound_body, func->func_type_annotation(), free_type_vars); lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index bc6e40740da8..34aaf4689a59 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -44,8 +44,7 @@ template bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] - ICHECK_EQ(types.size(), 2) - << "the arity of concatenate is 2, not " << types.size(); + ICHECK_EQ(types.size(), 2) << "the arity of concatenate is 2, not " << types.size(); /* If we receive a tuple we can continue, if we receive * anything but an incomplete type we should signal an * error. 
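The lambda_lift.cc hunks above are formatting cleanups, but the pass matters to this series: the converter's nested while_loops become local (possibly recursive) closures that the VM compiler must lift to top-level functions. As a rough Python analogy only — not the Relay pass itself — lambda lifting turns a closure's captured variables into explicit parameters:

    from functools import partial

    # A closure: `scale` is a free variable captured from the enclosing scope.
    def make_scaler(scale):
        return lambda x: x * scale

    # After lifting: the free variable is an explicit parameter of a
    # top-level function, and the original site becomes a partial application.
    def lifted_scaler(scale, x):
        return x * scale

    def make_scaler_lifted(scale):
        return partial(lifted_scaler, scale)

    assert make_scaler(3)(4) == make_scaler_lifted(3)(4) == 12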
@@ -53,7 +52,7 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) + Diagnostic::Error(reporter->GetSpan()) << "concatenate requires a tuple of tensors as the first argument, found " << PrettyPrint(types[0])); return false; @@ -63,8 +62,8 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs const auto* param = attrs.as(); if (param == nullptr) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) << "the call attributes are not defined"); + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "the call attributes are not defined"); return false; } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index c95cdc8286f9..473b5d759272 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -261,19 +261,6 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In TVMRetValue rv; func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } -std::ostream& operator<<(std::ostream& out, const NDArray& array) { - static const PackedFunc* fprint = Registry::Get("relay._ndarray_repr"); - CHECK(fprint); - std::string data = (*fprint)(array); - out << data; - return out; -} - -void print(const ObjectRef& arg) { - if (arg->IsInstance()) { - std::cout << "\t" << Downcast(arg) << std::endl; - } -} void VirtualMachine::LoadExecutable(const Executable* exec) { ICHECK(exec) << "The executable is not created yet."; @@ -293,7 +280,6 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true); ICHECK(pf != nullptr) << "Cannot find function in module: " << packed_name; packed_funcs_[packed_index] = pf; - std::cout << packed_name << " -> " << packed_index < Date: Thu, 10 Dec 2020 14:02:43 -0700 Subject: [PATCH 08/10] better parallelize get_valid_counts --- python/tvm/topi/cuda/nms.py | 233 +++++++++++++++++++++++++++++------- 1 file changed, 187 insertions(+), 46 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 12ee057d040a..45cb4a1f43e2 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -116,24 +116,14 @@ def rearrange_indices_out_ir(data, output, valid_box_count): return ib.get() -def get_valid_counts_ir( - data, valid_count, out, out_indices, score_threshold, id_index, score_index -): - """Low level IR to get valid count of bounding boxes - given a score threshold. Also prepares to move valid boxes to the - top of input data. +def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index): + """Low level IR to identify bounding boxes given a score threshold. Parameters ---------- data : Buffer Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length]. - valid_count : Buffer - 1D buffer for valid number of boxes with shape [batch_size, ]. - - flag : Buffer - 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - score_threshold : Buffer or float32 Lower limit of score for valid bounding boxes. @@ -145,8 +135,9 @@ def get_valid_counts_ir( Returns ------- - stmt : Stmt - The result IR statement. + valid_boxes: Buffer + 2D Buffer indicating valid boxes with shape [batch_size, num_anchors]. 
+    """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
@@ -156,15 +147,70 @@
 
     data = ib.buffer_ptr(data)
 
-    valid_count = ib.buffer_ptr(valid_count)
-    out = ib.buffer_ptr(out)
-    out_indices = ib.buffer_ptr(out_indices)
-    one = tvm.tir.const(1, dtype=out.dtype)
+    valid_boxes = ib.buffer_ptr(valid_boxes)
     if isinstance(score_threshold, float):
         score_threshold = tvm.tir.FloatImm("float32", score_threshold)
     id_index = tvm.tir.IntImm("int32", id_index)
     score_index = tvm.tir.IntImm("int32", score_index)
 
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = num_anchors // max_threads + 1
+        nthread_by = batch_size
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            score = data[(i * num_anchors + j) * elem_length + score_index]
+            with ib.if_scope(
+                tvm.tir.all(
+                    score > score_threshold,
+                    tvm.tir.any(
+                        id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
+                    ),
+                )
+            ):
+                valid_boxes[i * num_anchors + j] = 1
+            with ib.else_scope():
+                valid_boxes[i * num_anchors + j] = 0
+    return ib.get()
+
+
+def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
+    """Low level IR to get the output indices of valid boxes
+    and the count of valid boxes.
+
+    Parameters
+    ----------
+    valid_boxes: Buffer
+        2D Buffer indicating valid boxes with shape [batch_size, num_anchors].
+
+    Returns
+    -------
+    valid_count: Buffer
+        1D Buffer of number of valid boxes per batch [batch_size].
+
+    valid_indices: Buffer
+        2D Buffer of sorted output indices of valid boxes, with shape [batch_size, num_anchors].
+    """
+    batch_size = valid_boxes.shape[0]
+    num_anchors = valid_boxes.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+
+    valid_boxes = ib.buffer_ptr(valid_boxes)
+
+    valid_count = ib.buffer_ptr(valid_count)
+    valid_indices = ib.buffer_ptr(valid_indices)
+
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     with ib.new_scope():
         nthread_tx = max_threads
@@ -174,29 +220,96 @@
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         tid = bx * max_threads + tx
+        # TODO(mbrookhart): Parallelize the sum and cumsum here
+        current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
         with ib.if_scope(tid < batch_size):
+            current_index[0] = 0
             valid_count[tid] = 0
-            i = tid
             with ib.for_range(0, num_anchors) as j:
-                score = data[(i * num_anchors + j) * elem_length + score_index]
-                with ib.if_scope(
-                    tvm.tir.all(
-                        score > score_threshold,
-                        tvm.tir.any(
-                            id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0
-                        ),
-                    )
-                ):
-                    with ib.for_range(0, elem_length) as k:
-                        out[(i * num_anchors + valid_count[i]) * elem_length + k] = data[
-                            (i * num_anchors + j) * elem_length + k
-                        ]
-                    out_indices[i * num_anchors + valid_count[i]] = j
-                    valid_count[i] += 1
-                with ib.if_scope(j >= valid_count[i]):
-                    with ib.for_range(0, elem_length) as k:
-                        out[(i * num_anchors + j) * elem_length + k] = -one
-                    out_indices[i * num_anchors + j] = -1
+                idx = tid * num_anchors + j
+                valid_count[tid] = valid_count[tid] + valid_boxes[idx]
+                with ib.if_scope(valid_boxes[idx] == 1):
+                    valid_indices[idx] = current_index[0]
+                    current_index[0] = current_index[0] + 1
+                with ib.else_scope():
+                    valid_indices[idx] = -1
     return ib.get()
 
 
+def get_valid_counts_ir(data, valid_indices, out, out_indices):
+    """Low level IR to move valid boxes to the top of the input data,
+    using the output positions computed by get_valid_indices_ir.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
+
+    valid_indices: Buffer
+        2D Buffer of output positions of valid boxes (-1 for invalid entries),
+        with shape [batch_size, num_anchors].
+
+    Returns
+    -------
+    out : Buffer
+        Valid boxes moved to the top of the buffer; remaining entries are set to -1.
+
+    out_indices : Buffer
+        Indices of valid boxes in the original input data
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+
+    valid_indices = ib.buffer_ptr(valid_indices)
+    out = ib.buffer_ptr(out)
+    out_indices = ib.buffer_ptr(out_indices)
+    one = tvm.tir.const(1, dtype=out.dtype)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = num_anchors // max_threads + 1
+    nthread_by = batch_size
+    nthread_bz = elem_length
+    # First kernel: initialize all output entries to -1
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            k = bz
+            out[(i * num_anchors + j) * elem_length + k] = -one
+            out_indices[i * num_anchors + j] = -1
+    # Second kernel: scatter valid boxes to their compacted positions
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            k = bz
+            with ib.if_scope(valid_indices[i, tid] >= 0):
+                out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
+                    (i * num_anchors + j) * elem_length + k
+                ]
+                out_indices[i * num_anchors + valid_indices[i, tid]] = j
     return ib.get()
 
 
@@ -229,23 +342,51 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    valid_boxes_buf = tvm.tir.decl_buffer(
+        (batch_size, num_anchors), "int32", "valid_boxes_buf", data_alignment=8
+    )
+    valid_boxes = te.extern(
+        [(batch_size, num_anchors)],
+        [data],
+        lambda ins, outs: get_valid_boxes_ir(
+            ins[0], outs[0], score_threshold, id_index, score_index
+        ),
+        dtype=["int32"],
+        in_buffers=[data_buf],
+        out_buffers=[valid_boxes_buf],
+        name="get_valid_boxes",
+        tag="get_valid_boxes_gpu",
+    )
+
+    valid_indices_buf = tvm.tir.decl_buffer(
+        (batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8
+    )
     valid_count_buf = tvm.tir.decl_buffer(
         (batch_size,), "int32", "valid_count_buf", data_alignment=8
     )
+    valid_count, valid_indices = te.extern(
+        [(batch_size,), (batch_size, num_anchors)],
+        [valid_boxes],
+        lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),
+        dtype=["int32"],
+        in_buffers=[valid_boxes_buf],
+        out_buffers=[valid_count_buf, valid_indices_buf],
+        name="get_valid_indices",
+        tag="get_valid_indices_gpu",
+    )
+
     out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
     out_indices_buf = tvm.tir.decl_buffer(
-        (batch_size, num_anchors), "int32", "out_buf", data_alignment=8
+        (batch_size, num_anchors), "int32", "out_indices_buf", data_alignment=8
     )
 
-    valid_count, out, out_indices = te.extern(
-        [(batch_size,), data.shape, (batch_size, num_anchors)],
-        [data],
-        lambda ins, outs: get_valid_counts_ir(
-            ins[0], outs[0], outs[1], outs[2], score_threshold, id_index,
score_index - ), + out, out_indices = te.extern( + [data.shape, (batch_size, num_anchors)], + [data, valid_indices], + lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]), dtype=["int32", data.dtype], - in_buffers=[data_buf], - out_buffers=[valid_count_buf, out_buf, out_indices_buf], + in_buffers=[data_buf, valid_indices_buf], + out_buffers=[out_buf, out_indices_buf], name="get_valid_counts", tag="get_valid_counts_gpu", ) From 1bb57bd4b6318e1420cb728a2b38eaf37c417916 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Thu, 10 Dec 2020 17:21:37 -0700 Subject: [PATCH 09/10] improve nms parallelization --- python/tvm/topi/cuda/nms.py | 180 +++++++++++++++++++++--------------- 1 file changed, 105 insertions(+), 75 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 45cb4a1f43e2..fa2a38e4a8d2 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -512,7 +512,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): indices = ib.buffer_ptr(indices) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) - num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") if isinstance(iou_threshold, float): iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) @@ -525,86 +524,117 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", 1) - - with ib.for_range(0, batch_size) as i: - base_idx = i * num_anchors * box_data_length - with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): - # Reorder output - nkeep = if_then_else( - tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] - ) - with ib.for_range(0, nkeep) as j: + nthread_by = batch_size + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + i = by + base_idx = i * num_anchors * box_data_length + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + # Reorder output + nkeep = if_then_else( + tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] + ) + with ib.for_range(0, nkeep) as j: + with ib.for_range(0, box_data_length) as k: + out[(base_idx + j * box_data_length + k)] = data[ + (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) + ] + box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] + with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])): + with ib.for_range(0, valid_count[i] - nkeep) as j: with ib.for_range(0, box_data_length) as k: - out[(base_idx + j * box_data_length + k)] = data[ - (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) - ] - box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] - with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])): - with ib.for_range(0, valid_count[i] - nkeep) as j: - with ib.for_range(0, box_data_length) as k: - out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 - box_indices[i * num_anchors + (j + nkeep)] = -1 - # Apply nms - with ib.for_range(0, valid_count[i]) as j: - with ib.for_range(0, j) as k: - offset_k = k * box_data_length + out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + (j + nkeep)] = -1 + with ib.new_scope(): + nthread_by = batch_size + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + i = by + base_idx = i * num_anchors * 
box_data_length + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + # Apply nms + with ib.for_range(0, valid_count[i]) as j: + with ib.for_range(0, j) as k: + offset_k = k * box_data_length + with ib.if_scope( + tvm.tir.all( + out[base_idx + offset_k + score_index] > 0, + tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + ) + ): + offset_j = j * box_data_length with ib.if_scope( tvm.tir.all( + j > k, out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0), + tvm.tir.any( + force_suppress > 0, + id_index < 0, + out[base_idx + offset_k + id_index] + == out[base_idx + offset_j + id_index], + ), ) ): - offset_j = j * box_data_length - with ib.if_scope( - tvm.tir.all( - j > k, - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any( - id_index < 0, out[base_idx + offset_j + id_index] >= 0 - ), - tvm.tir.any( - force_suppress > 0, - id_index < 0, - out[base_idx + offset_k + id_index] - == out[base_idx + offset_j + id_index], - ), - ) - ): - iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, - ) - with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_j + score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_j + id_index] = -1.0 - box_indices[i * num_anchors + j] = -1 - with ib.else_scope(): - with ib.for_range(0, valid_count[i]) as j: - offset_j = j * box_data_length - with ib.for_range(0, box_data_length) as k: - out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] - box_indices[i * num_anchors + j] = j - # Set invalid entry to be -1 - with ib.for_range(0, num_anchors - valid_count[i]) as j: - with ib.for_range(0, box_data_length) as k: - out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 - box_indices[i * num_anchors + j + valid_count[i]] = -1 - # Only return max_output_size number of valid boxes - num_valid_boxes[0] = 0 - with ib.if_scope(max_output_size > 0): - with ib.for_range(0, valid_count[i]) as j: - offset_j = j * box_data_length - with ib.if_scope(out[base_idx + offset_j] >= 0): - with ib.if_scope(num_valid_boxes[0] == max_output_size): - with ib.for_range(0, box_data_length) as k: - out[base_idx + offset_j + k] = -1.0 - box_indices[i * num_anchors + j] = -1 - with ib.else_scope(): - num_valid_boxes[0] += 1 + iou = calculate_overlap( + out, + base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start, + ) + with ib.if_scope(iou >= iou_threshold): + out[base_idx + offset_j + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_j + id_index] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + nthread_by = batch_size + nthread_bz = box_data_length + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + tid = bx * max_threads + tx + i = by + j = tid + k = bz + base_idx = i * num_anchors * box_data_length + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + pass + with ib.else_scope(): + with ib.if_scope(j < valid_count[i]): + offset_j = j * box_data_length + out[(base_idx + 
offset_j + k)] = data[base_idx + offset_j + k] + box_indices[i * num_anchors + j] = j + + with ib.new_scope(): + num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", batch_size) + i = bx + base_idx = i * num_anchors * box_data_length + # Set invalid entry to be -1 + with ib.for_range(0, num_anchors - valid_count[i]) as j: + with ib.for_range(0, box_data_length) as k: + out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 + box_indices[i * num_anchors + j + valid_count[i]] = -1 + # Only return max_output_size number of valid boxes + num_valid_boxes[0] = 0 + with ib.if_scope(max_output_size > 0): + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(out[base_idx + offset_j] >= 0): + with ib.if_scope(num_valid_boxes[0] == max_output_size): + with ib.for_range(0, box_data_length) as k: + out[base_idx + offset_j + k] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): + num_valid_boxes[0] += 1 if return_indices: with ib.new_scope(): From e027ecd656f5dd2f437cfbe4d75c4bbe583cd24a Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Sat, 12 Dec 2020 19:03:08 -0700 Subject: [PATCH 10/10] respond to cuda/thrust enablement issue --- python/tvm/topi/cuda/nms.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index fa2a38e4a8d2..d0915d9aa55f 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -755,7 +755,12 @@ def non_max_suppression( score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) - if tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True): + target = tvm.target.Target.current() + if ( + target + and target.kind.name == "cuda" + and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True) + ): sort_tensor = argsort_thrust( score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype )
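The restored guard matters because a Thrust-enabled build registers the packed function globally, even when the schedule is being compiled for a non-CUDA target that reuses these TOPI definitions; checking for the function alone would then select a kernel the target cannot run. A small standalone sketch of the same feature-detection pattern — the helper name is illustrative, but the calls mirror the guard above:

    import tvm

    def can_use_thrust_sort_nms():
        # True only when the active target is CUDA *and* the Thrust kernel
        # was compiled into this build; either check alone is insufficient.
        target = tvm.target.Target.current()
        return bool(
            target
            and target.kind.name == "cuda"
            and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
        )

For example, inside `with tvm.target.Target("cuda"):` the helper returns True only on builds compiled with USE_THRUST enabled.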