diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 9bdac71b6ee4..5de5d0a067b0 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -164,6 +164,14 @@ This level enables additional math and transform operators.
    tvm.relay.vision.yolo_reorg
 
 
+**Level 6: Algorithm Operators**
+
+.. autosummary::
+   :nosignatures:
+
+   tvm.relay.argsort
+
+
 **Level 10: Temporary Operators**
 
 This level supports backpropagation of broadcast operators. It is temporary.
@@ -292,6 +300,11 @@ Level 5 Definitions
 .. autofunction:: tvm.relay.vision.yolo_reorg
 
 
+Level 6 Definitions
+-------------------
+.. autofunction:: tvm.relay.argsort
+
+
 Level 10 Definitions
 --------------------
 .. autofunction:: tvm.relay.broadcast_to_like
diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h
new file mode 100644
index 000000000000..20f135c11bba
--- /dev/null
+++ b/include/tvm/relay/attrs/algorithm.h
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/attrs/algorithm.h
+ * \brief Auxiliary attributes for algorithm operators.
+ */
+#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
+#define TVM_RELAY_ATTRS_ALGORITHM_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Attributes used in argsort operators */
+struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
+  int axis;
+  bool is_ascend;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(-1)
+      .describe("Axis along which to sort the input tensor. "
+                "If not given, the flattened array is used.");
+    TVM_ATTR_FIELD(is_ascend).set_default(true)
+      .describe("Whether to sort in ascending or descending order. "
+                "By default, sort in ascending order.");
+    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+      .describe("DType of the output indices.");
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_ATTRS_ALGORITHM_H_
diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h
index 2b3eb4f32b45..11b4ebfcfaad 100644
--- a/include/tvm/relay/attrs/vision.h
+++ b/include/tvm/relay/attrs/vision.h
@@ -92,6 +92,8 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
   double iou_threshold;
   bool force_suppress;
   int top_k;
+  int coord_start;
+  int score_index;
   int id_index;
   bool return_indices;
   bool invalid_to_bottom;
diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc
--- a/src/contrib/sort/sort.cc
+++ b/src/contrib/sort/sort.cc
@@ -56,19 +56,19 @@ bool CompareDescend(const std::pair<int64_t, DType>& lhs,
 }
 
-// Argsort implemented using C library sort.
+// Argsort implemented using C library sort, for nms.
 // Return indices of sorted tensor.
 // By default, the last axis will be used to sort.
 // sort_num specify the number of elements to be sorted.
 // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
 // and sort axis is dk. sort_num should have dimension of
 // (d1, d2, ..., d(k-1), d(k+1), ..., dn).
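For reference, the contract spelled out in the comment above amounts to the following NumPy sketch (an illustration only: the helper name is hypothetical, and only the float32-data / int32-index case the kernel supports is modelled):

    import numpy as np

    def argsort_nms_ref(data, sort_num, axis, is_ascend):
        # Sort only the first sort_num[...] elements of each lane along
        # `axis`; positions past the valid count keep their own index.
        out = np.zeros(data.shape, dtype="int32")
        lane_dims = [s for i, s in enumerate(data.shape) if i != axis]
        for idx in np.ndindex(*lane_dims):
            sel = list(idx)
            sel.insert(axis, slice(None))
            lane = data[tuple(sel)]
            n = int(sort_num[idx])
            key = lane[:n] if is_ascend else -lane[:n]
            order = np.argsort(key, kind="stable").astype("int32")
            out[tuple(sel)] = np.concatenate(
                [order, np.arange(n, lane.size, dtype="int32")])
        return out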
-TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   DLTensor *input = args[0];
   DLTensor *sort_num = args[1];
   DLTensor *output = args[2];
   int32_t axis = args[3];
-  bool is_descend = args[4];
+  bool is_ascend = args[4];
 
   auto dtype = input->dtype;
   auto data_ptr = static_cast<float *>(input->data);
@@ -97,10 +97,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
       int64_t full_idx = base_idx + k * axis_mul_after;
       sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
     }
-    if (is_descend) {
-      std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
-    } else {
+    if (is_ascend) {
       std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+    } else {
+      std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
     }
     for (int32_t k = 0; k < input->shape[axis]; ++k) {
       *(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
@@ -110,5 +110,68 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
   }
 });
 
+
+// Argsort implemented using C library sort.
+// Return indices of sorted tensor.
+// By default, the last axis will be used to sort.
+TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor *input = args[0];
+  DLTensor *output = args[1];
+  int32_t axis = args[2];
+  bool is_ascend = args[3];
+
+  auto dtype = input->dtype;
+  auto data_ptr = static_cast<float *>(input->data);
+  std::vector<std::pair<int64_t, float>> sorter;
+  int64_t axis_mul_before = 1;
+  int64_t axis_mul_after = 1;
+
+  if (axis < 0) {
+    axis = input->ndim + axis;
+  }
+
+  // Currently only supports input dtype to be float32.
+  CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
+                             "to be float32.";
+  CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
+                              "to be float32.";
+  CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
+                                 "input ndim " << input->ndim;
+
+  for (int i = 0; i < input->ndim; ++i) {
+    if (i < axis) {
+      axis_mul_before *= input->shape[i];
+    } else if (i > axis) {
+      axis_mul_after *= input->shape[i];
+    }
+  }
+
+  int32_t current_sort_num = input->shape[axis];
+  for (int64_t i = 0 ; i < axis_mul_before; ++i) {
+    for (int64_t j = 0 ; j < axis_mul_after; ++j) {
+      sorter.clear();
+      int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
+      for (int64_t k = 0; k < current_sort_num; ++k) {
+        int64_t full_idx = base_idx + k * axis_mul_after;
+        sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
+      }
+      if (is_ascend) {
+        std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
+      } else {
+        std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
+      }
+      for (int32_t k = 0; k < input->shape[axis]; ++k) {
+        *(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
+          = k < static_cast<int32_t>(sorter.size()) ? sorter[k].first : k;
+      }
+    }
+  }
+});
+
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/relay/op/algorithm/sort.cc b/src/relay/op/algorithm/sort.cc
new file mode 100644
index 000000000000..5777b79699b1
--- /dev/null
+++ b/src/relay/op/algorithm/sort.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file sort.cc
+ * \brief Argsort operators
+ */
+#include <tvm/relay/attrs/algorithm.h>
+#include <tvm/relay/op.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(ArgsortAttrs);
+
+bool ArgsortRel(const Array<Type>& types,
+                int num_inputs,
+                const Attrs& attrs,
+                const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  const ArgsortAttrs* param = attrs.as<ArgsortAttrs>();
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+      << "Argsort: expect input type to be TensorType but get "
+      << types[0];
+    return false;
+  }
+  CHECK_EQ(param->dtype, Float(32));
+  reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
+  return true;
+}
+
+Expr MakeArgsort(Expr data,
+                 int axis,
+                 bool is_ascend,
+                 DataType dtype) {
+  auto attrs = make_node<ArgsortAttrs>();
+  attrs->axis = axis;
+  attrs->is_ascend = is_ascend;
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("argsort");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op._make.argsort")
+.set_body_typed(MakeArgsort);
+
+RELAY_REGISTER_OP("argsort")
+.describe(R"doc(Returns the indices that would sort an
+input array along the given axis.
+)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ArgsortAttrs") +.add_argument("data", "Tensor", "Input data.") +.set_support_level(6) +.add_type_rel("Argsort", ArgsortRel); +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 5344bce3d641..2e5661cdc4dc 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -106,6 +106,8 @@ Expr MakeNMS(Expr data, double iou_threshold, bool force_suppress, int top_k, + int coord_start, + int score_index, int id_index, bool return_indices, bool invalid_to_bottom) { @@ -114,6 +116,8 @@ Expr MakeNMS(Expr data, attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; attrs->top_k = top_k; + attrs->coord_start = coord_start; + attrs->score_index = score_index; attrs->id_index = id_index; attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index 856d3fa9cf83..87cdac01ce3a 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -24,11 +24,11 @@ def test_sort(): data = tvm.placeholder((n, l, m), name='data') sort_num = tvm.placeholder((n, m), name="sort_num", dtype="int32") axis = 1 - is_descend = True + is_ascend = False out = tvm.extern(data.shape, [data, sort_num], lambda ins, outs: tvm.call_packed( - "tvm.contrib.sort.argsort", ins[0], - ins[1], outs[0], axis, is_descend), + "tvm.contrib.sort.argsort_nms", ins[0], + ins[1], outs[0], axis, is_ascend), dtype='int32', name="sort_tensor") input = [[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]], [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]] @@ -50,13 +50,13 @@ def test_sort_np(): dshape = (1, 2, 3, 4, 5, 6) axis = 4 reduced_shape = (1, 2, 3, 4, 6) - is_descend = False + is_ascend = True data = tvm.placeholder(dshape, name='data') sort_num = tvm.placeholder(reduced_shape, name="sort_num", dtype="int32") out = tvm.extern(data.shape, [data, sort_num], lambda ins, outs: tvm.call_packed( - "tvm.contrib.sort.argsort", ins[0], - ins[1], outs[0], axis, is_descend), + "tvm.contrib.sort.argsort_nms", ins[0], + ins[1], outs[0], axis, is_ascend), dtype='int32', name="sort_tensor") ctx = tvm.cpu(0) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 7e1c37169978..e6d99c765c87 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -177,12 +177,13 @@ def verify_get_valid_counts(dshape, score_threshold): assert "score_threshold" in z.astext() func = relay.Function([x], z.astuple()) func = relay.ir_pass.infer_type(func) - ctx_list = [("llvm", tvm.cpu(0))] - for target, ctx in ctx_list: + for target, ctx in ctx_list(): + if target == 'cuda': + return intrp = relay.create_executor("debug", ctx=ctx, target=target) out = intrp.evaluate(func)(np_data) - tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3) - tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3) + tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) + tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) verify_get_valid_counts((1, 2500, 6), 0) verify_get_valid_counts((1, 2500, 6), -1) @@ -195,9 +196,13 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, iou_threshold=0.5, force_suppress=False, top_k=-1, check_type_only=False): x0 = relay.var("x0", relay.ty.TensorType(dshape, 
"float32")) - x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int")) - z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False) - z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k) + x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32")) + z = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ + iou_threshold = iou_threshold, force_suppress = force_suppress, \ + top_k = top_k, return_indices=False) + z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ + iou_threshold = iou_threshold, force_suppress = force_suppress, \ + top_k = top_k) assert "iou_threshold" in z.astext() assert "iou_threshold" in z_indices.astext() zz = relay.ir_pass.infer_type(z) @@ -212,8 +217,7 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, func = relay.ir_pass.infer_type(func) func_indices = relay.Function([x0, x1], z_indices) func_indices = relay.ir_pass.infer_type(func_indices) - ctx_list = [("llvm", tvm.cpu(0))] - for target, ctx in ctx_list: + for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x0_data, x1_data) op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data) @@ -296,8 +300,7 @@ def test_default_value(): nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False) func = relay.Function([cls_prob, loc_pred, anchors], nms) func = relay.ir_pass.infer_type(func) - ctx_list = [("llvm", tvm.cpu(0))] - for target, ctx in ctx_list: + for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py new file mode 100644 index 000000000000..983a9154df34 --- /dev/null +++ b/tests/python/relay/test_op_level6.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Support level6 operator test cases. 
+""" +import math +import numpy as np +import tvm +from tvm import relay +from tvm.relay.testing import ctx_list +import topi.testing + +def test_argsort(): + def verify_argsort(shape, axis, is_ascend): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.argsort(x, axis=axis, is_ascend=is_ascend) + zz = relay.ir_pass.infer_type(z) + func = relay.Function([x], z) + x_data = np.random.uniform(size=shape).astype("float32") + if is_ascend: + ref_res = np.argsort(x_data, axis=axis) + else: + ref_res = np.argsort(-x_data, axis=axis) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5) + verify_argsort((2, 3, 4), axis=0, is_ascend=False) + verify_argsort((1, 4, 6), axis=1, is_ascend=True) + verify_argsort((3, 5, 6), axis=-1, is_ascend=False) + + +if __name__ == "__main__": + test_argsort() diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index 2eb460d151ae..a9984148d5d3 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -21,6 +21,7 @@ from .reduction import * from .transform import * from .broadcast import * +from .sort import * from . import nn from . import x86 from . import cuda diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index e6377fa40c52..5d04d72a7eca 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -20,77 +20,380 @@ import tvm from tvm import api -from topi.vision import non_max_suppression -from ..util import get_const_tuple +from tvm.generic import cast +from tvm.intrin import if_then_else, log, power +from topi.vision import non_max_suppression, get_valid_counts +from .sort import argsort -def sort_ir(data, index, output): - """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + +def get_valid_counts_pre(data, flag, idx, score_threshold): + """Low level IR to Prepare get valid count of bounding boxes + given a score threshold. Also moves valid boxes to the + top of input data. Parameters ---------- data: Buffer - 2D Buffer of input boxes' score with shape [batch_size, num_anchors]. + 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. + + flag : Buffer + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - index : Buffer - 1D Buffer of number of valid number of boxes. + idx : Buffer + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. - output : Buffer - 2D Output buffer of indicies of sorted tensor with shape [batch_size, num_anchors]. + score_threshold : float32 + Lower limit of score for valid bounding boxes. Returns ------- stmt : Stmt The result IR statement. 
""" + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold) - assert data.dtype == "float32", "Currently only supports input dtype to be float32" - batch, num_anchors = get_const_tuple(data.shape) max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = batch_size * num_anchors // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + + with ib.if_scope(tid < batch_size * num_anchors): + with ib.if_scope(data[tid * box_data_length + 1] > score_threshold): + flag[tid] = 1 + idx[tid] = 1 + with ib.else_scope(): + flag[tid] = 0 + idx[tid] = 0 + + return ib.get() + +def get_valid_counts_upsweep(data, idx_in, idx, partial): + """Low level IR of first step of scan: unsweep. + + Parameters + ---------- + data: Buffer + 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. + + idx_in : Buffer + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. + + idx : Buffer + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. + + partial : Buffer + 2D Buffer of valid data indices with shape [batch_size, new_range]. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] ib = tvm.ir_builder.create() - p_data = ib.buffer_ptr(data) - p_index = ib.buffer_ptr(index) - p_out = ib.buffer_ptr(output) + data = ib.buffer_ptr(data) + idx_in = ib.buffer_ptr(idx_in) + idx = ib.buffer_ptr(idx) + partial = ib.buffer_ptr(partial) + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + elem_per_thread = num_anchors // max_threads + 1 nthread_tx = max_threads - nthread_bx = num_anchors // max_threads + 1 + nthread_bx = batch_size tx = tvm.thread_axis("threadIdx.x") - bx = tvm.thread_axis("vthread") + bx = tvm.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "virtual_thread", nthread_bx) - tid = bx * nthread_tx + tx - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") - - with ib.for_range(0, batch, for_type="unroll") as b: - start = b * num_anchors - with ib.if_scope(tid < num_anchors): - p_out[start + tid] = tid - # OddEvenTransposeSort - with ib.for_range(0, p_index[b]) as k: - with ib.if_scope(tid < (p_index[b] + 1) // 2): - offset = start + 2 * tid + (k % 2) - with ib.if_scope( \ - tvm.all(offset + 1 < p_index[0], p_data[offset] < p_data[offset + 1])): - temp_data[0] = p_data[offset] - p_data[offset] = p_data[offset + 1] - p_data[offset + 1] = temp_data[0] - temp_index[0] = p_out[offset] - p_out[offset] = p_out[offset + 1] - p_out[offset + 1] = temp_index[0] + ib.scope_attr(bx, "thread_extent", nthread_bx) + new_range = num_anchors // elem_per_thread + 1 + # Scan: Upsweep: + with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)): + with ib.for_range(0, elem_per_thread) as i: + with ib.if_scope(bx * num_anchors + \ + tx * elem_per_thread + i < batch_size * num_anchors): + with ib.if_scope(i == 0): + partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * 
+                    idx[bx * num_anchors + tx * elem_per_thread] = \
+                        idx_in[bx * num_anchors + tx * elem_per_thread]
+                with ib.else_scope():
+                    partial[bx * new_range + tx] += \
+                        idx_in[bx * num_anchors + tx * elem_per_thread + i]
+                    idx[bx * num_anchors + tx * elem_per_thread + i] = \
+                        idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
+                        idx_in[bx * num_anchors + tx * elem_per_thread + i]
+    return ib.get()
+
+def get_valid_counts_scan(data, partial_in, partial):
+    """Low level IR to do a Kogge-Stone scan over the per-chunk partial sums.
+
+    Parameters
+    ----------
+    data: Buffer
+        3D Buffer with shape [batch_size, num_anchors, elem_length], input box data.
+
+    partial_in : Buffer
+        2D Buffer of per-chunk partial sums with shape [batch_size, new_range].
+
+    partial : Buffer
+        2D Buffer of scanned partial sums with shape [batch_size, new_range].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    ib = tvm.ir_builder.create()
+    partial_in = ib.buffer_ptr(partial_in)
+    partial = ib.buffer_ptr(partial)
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    elem_per_thread = num_anchors // max_threads + 1
+    nthread_tx = max_threads
+    nthread_bx = batch_size
+    tx = tvm.thread_axis("threadIdx.x")
+    bx = tvm.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    var = tvm.make.node("FloatImm", dtype="float32", value=2)
+    new_range = num_anchors // elem_per_thread + 1
+    iteration = log(cast(new_range, "float32")) // math.log(2)
+    # Scan: Kogge-Stone adder
+    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
+        with ib.for_range(0, iteration) as k:
+            with ib.if_scope(k == 0):
+                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
+                    partial[bx * new_range + tx] = \
+                        partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
+                with ib.else_scope():
+                    partial[bx * new_range] = partial_in[bx * new_range]
+            with ib.else_scope():
+                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
+                                         tx < tvm.min(new_range, num_anchors))):
+                    partial[bx * new_range + tx] += \
+                        partial[bx * new_range + tx - cast(power(var, k), "int32")]
             ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                   tvm.convert(['shared']),
                                   tvm.expr.Call.Intrinsic, None, 0))
+    return ib.get()
+
+def get_valid_counts_downsweep(data, idx_in, partial, idx):
+    """Low level IR to do the downsweep of the scan.
+
+    Parameters
+    ----------
+    data: Buffer
+        3D Buffer with shape [batch_size, num_anchors, elem_length], input box data.
+
+    idx_in : Buffer
+        2D Buffer of per-chunk prefix sums with shape [batch_size, num_anchors].
+
+    partial : Buffer
+        2D Buffer of scanned partial sums with shape [batch_size, new_range].
+
+    idx : Buffer
+        2D Buffer of final valid data indices with shape [batch_size, num_anchors].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
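The Kogge-Stone loop above finishes in about log2(new_range) rounds because each round lets every element absorb the partial sum 2^k slots to its left, doubling the reach each time. A serial NumPy sketch of one batch row (illustrative only; the helper name is assumed):

    import numpy as np

    def kogge_stone_scan_ref(row):
        # Inclusive prefix sum via Kogge-Stone shifted adds.
        x = row.astype("int64").copy()
        k = 1
        while k < x.size:
            x[k:] += x[:-k].copy()   # add the value 2^step slots to the left
            k *= 2
        return x

    kogge_stone_scan_ref(np.array([1, 0, 2, 1]))  # -> [1, 1, 3, 4]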
+ """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + ib = tvm.ir_builder.create() + idx_in = ib.buffer_ptr(idx_in) + idx = ib.buffer_ptr(idx) + partial = ib.buffer_ptr(partial) + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + elem_per_thread = num_anchors // max_threads + 1 + nthread_tx = max_threads + nthread_bx = batch_size * num_anchors // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + new_range = num_anchors // elem_per_thread + 1 + # Scan: Downsweep: + with ib. if_scope(tid < batch_size * num_anchors): + i = tid / num_anchors # number of batches + j = tid % num_anchors # number of anchors + with ib.if_scope(j < elem_per_thread): + idx[tid] = idx_in[tid] + with ib.else_scope(): + idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1] return ib.get() -def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk): +def get_valid_counts_ir(data, flag, idx, valid_count, out): + """Low level IR to get valid count of bounding boxes + given a score threshold. Also moves valid boxes to the + top of input data. + + Parameters + ---------- + data : Buffer + Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length]. + + flag : Buffer + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. + + idx : Buffer + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. + + valid_count : Buffer + 1-D buffer for valid number of boxes. + + out : Buffer + Rearranged data buffer. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + size = batch_size * num_anchors * elem_length + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + valid_count = ib.buffer_ptr(valid_count) + out = ib.buffer_ptr(out) + + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + + with ib.if_scope(tid < batch_size * num_anchors): + i = tid / num_anchors + j = tid % num_anchors + base_idx = i * num_anchors * elem_length + with ib.if_scope(flag[tid] > 0): + with ib.for_range(0, elem_length) as k: + with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size): + out[base_idx + (idx[tid] - 1) * elem_length + k] =\ + data[base_idx + j * elem_length + k] + with ib.if_scope(j == 0): + valid_count[i] = idx[tid + num_anchors - 1] + with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]): + with ib.for_range(0, elem_length) as l: + with ib.if_scope(tid * elem_length + l < size): + out[tid * elem_length + l] = -1.0 + return ib.get() + + +@get_valid_counts.register(["cuda", "gpu"]) +def get_valid_counts_gpu(data, score_threshold=0): + """Get valid count of bounding boxes given a score threshold. + Also moves valid boxes to the top of input data. + + Parameters + ---------- + data : tvm.Tensor + Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length]. 
+
+    score_threshold : optional, float
+        Lower limit of score for valid bounding boxes.
+
+    Returns
+    -------
+    valid_count : tvm.Tensor
+        1-D tensor for valid number of boxes.
+
+    out_tensor : tvm.Tensor
+        Rearranged data tensor.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
+    elem_per_thread = num_anchors // max_threads + 1
+    new_range = num_anchors // elem_per_thread + 1
+    temp_flag_buf = api.decl_buffer(
+        (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
+    temp_idx_buf = api.decl_buffer(
+        (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
+    temp_partial_buf = api.decl_buffer(
+        (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
+    data_buf = api.decl_buffer(
+        data.shape, data.dtype, "data_buf", data_alignment=8)
+
+    temp_flag, temp_idx = \
+        tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
+                   lambda ins, outs: get_valid_counts_pre(
+                       ins[0], outs[0], outs[1], score_threshold),
+                   dtype=["int32", "int32"],
+                   out_buffers=[temp_flag_buf, temp_idx_buf],
+                   name="get_valid_counts_phase_one")
+    temp_idx_new, temp_partial = \
+        tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
+                   lambda ins, outs: get_valid_counts_upsweep(
+                       ins[0], ins[1], outs[0], outs[1]),
+                   dtype=["int32", "int32"],
+                   out_buffers=[temp_idx_buf, temp_partial_buf],
+                   name="get_valid_counts_phase_two")
+    temp_partial_new = \
+        tvm.extern([(batch_size, new_range)], [data, temp_partial],
+                   lambda ins, outs: get_valid_counts_scan(
+                       ins[0], ins[1], outs[0]),
+                   dtype=["int32"],
+                   out_buffers=[temp_partial_buf],
+                   name="get_valid_counts_phase_three")
+    temp_idx_final = \
+        tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
+                   lambda ins, outs: get_valid_counts_downsweep(
+                       ins[0], ins[1], ins[2], outs[0]),
+                   dtype=["int32"],
+                   out_buffers=[temp_idx_buf],
+                   name="get_valid_counts_phase_four")
+    valid_count, out_tensor = \
+        tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
+                   lambda ins, outs: get_valid_counts_ir(
+                       ins[0], ins[1], ins[2], outs[0], outs[1]),
+                   dtype=["int32", data.dtype],
+                   in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
+                   name="get_valid_counts_phase_five",
+                   tag="get_valid_counts_gpu")
+
+    return [valid_count, out_tensor]
+
+
+def nms_ir(data, sorted_index, valid_count, out, box_indices,
+           max_output_size, iou_threshold, force_suppress,
+           top_k, coord_start, id_index):
     """Low level IR routine for non maximum suppression.
 
     Parameters
     ----------
-    data: Buffer
+    data : Buffer
         Buffer of output boxes with class and score.
 
-    sort_result : Buffer
+    sorted_index : Buffer
         Buffer of output box indexes sorted by score.
 
     valid_count : Buffer
@@ -99,15 +402,25 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
     out : Buffer
         Output buffer.
 
-    nms_threshold : float
-        Non-maximum suppression threshold.
+    max_output_size : int
+        Max number of output valid boxes for each instance.
+        By default all valid boxes are returned.
+
+    iou_threshold : float
+        Overlapping (IoU) threshold to suppress object with smaller score.
 
     force_suppress : boolean
         Whether to suppress all detections regardless of class_id.
 
-    nms_topk : int
+    top_k : int
         Keep maximum top k detections before nms, -1 for no limit.
 
+    coord_start : int
+        Start index of the consecutive 4 coordinates.
+ + id_index : int + index of the class categories, -1 to disable. + Returns ------- stmt : Stmt @@ -127,100 +440,232 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i return tvm.expr.Select(u <= 0.0, 0.0, i / u) + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + sorted_index = ib.buffer_ptr(sorted_index) + valid_count = ib.buffer_ptr(valid_count) + out = ib.buffer_ptr(out) + box_indices = ib.buffer_ptr(box_indices) + num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") + max_threads = int(math.sqrt( tvm.target.current_target(allow_none=False).max_num_threads)) - ib = tvm.ir_builder.create() - p_data = ib.buffer_ptr(data) - p_sort_result = ib.buffer_ptr(sort_result) - p_valid_count = ib.buffer_ptr(valid_count) - p_out = ib.buffer_ptr(out) - batch_size = out.shape[0] - num_anchors = out.shape[1] nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) - i = bx * max_threads + tx - - nms_threshold_node = tvm.make.node( - "FloatImm", dtype="float32", value=nms_threshold) - nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk) - force_suppress_node = tvm.make.node( - "IntImm", dtype="int32", value=1 if force_suppress else 0) - with ib.for_range(0, batch_size, for_type="unroll") as b: - base_idx = b * num_anchors * 6 - with ib.if_scope( \ - tvm.all(nms_threshold_node > 0, nms_threshold_node < 1, - p_valid_count[0] > 0)): + k = bx * max_threads + tx + + iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold) + top_k = tvm.make.node("IntImm", dtype="int32", value=top_k) + coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start) + id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) + force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0) + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * box_data_length + with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)): # Reorder output - nkeep = tvm.if_then_else( \ - tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b]), - nms_topk, p_valid_count[b]) - with ib.for_range(0, nkeep) as l: - with ib.if_scope(i < 6): - p_out[(base_idx + l * 6 + i)] = \ - p_data[(base_idx + p_sort_result[b * num_anchors + l] * 6 + i)] - with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])): - with ib.for_range(0, p_valid_count[b] - nkeep) as l: - with ib.if_scope(i < 6): - p_out[(base_idx + (l + nkeep) * 6 + i)] = -1.0 + nkeep = if_then_else( \ + tvm.all(top_k > 0, top_k < valid_count[i]), + top_k, valid_count[i]) + with ib.for_range(0, nkeep) as j: + with ib.if_scope(k < box_data_length): + out[(base_idx + j * box_data_length + k)] = \ + data[(base_idx + sorted_index[i * num_anchors + j] \ + * box_data_length + k)] + box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] + with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])): + with ib.for_range(0, valid_count[i] - nkeep) as j: + with ib.if_scope(k < box_data_length): + out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + (j + nkeep)] = -1 # Apply nms - with ib.for_range(0, p_valid_count[b]) as l: - offset_l = l * 
6 - with ib.if_scope(p_out[base_idx + offset_l] >= 0): - with ib.if_scope(i < p_valid_count[b]): - offset_i = i * 6 - with ib.if_scope(tvm.all(i > l, p_out[base_idx - + offset_i] >= 0)): - with ib.if_scope(tvm.any(force_suppress_node > 0, - p_out[base_idx + offset_l] == - p_out[base_idx + offset_i])): - # When force_suppress == True or class_id equals - iou = calculate_overlap(p_out, base_idx + offset_l + 2, - base_idx + offset_i + 2) - with ib.if_scope(iou >= nms_threshold): - p_out[base_idx + offset_i] = -1.0 + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(out[base_idx + offset_j] >= 0): + with ib.if_scope(k < valid_count[i]): + offset_k = k * box_data_length + with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \ + tvm.any(force_suppress > 0, id_index < 0, \ + out[base_idx + offset_j] == \ + out[base_idx + offset_k]))): + iou = calculate_overlap(out, base_idx + offset_k + coord_start, + base_idx + offset_j + coord_start) + with ib.if_scope(iou >= iou_threshold): + out[base_idx + offset_k] = -1.0 + box_indices[i * num_anchors + k] = -1 ib.emit(tvm.make.Call(None, 'tvm_storage_sync', tvm.convert(['shared']), tvm.expr.Call.Intrinsic, None, 0)) with ib.else_scope(): - with ib.for_range(0, p_valid_count[b]) as c: - with ib.if_scope(i < 6): - p_out[(base_idx + c * 6 + i)] = p_data[base_idx + c * 6 + i] + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(k < box_data_length): + out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] + box_indices[i * num_anchors + j] = j # Set invalid entry to be -1 - with ib.for_range(0, num_anchors - p_valid_count[b]) as c: - with ib.if_scope(i < 6): - p_out[base_idx + (c + p_valid_count[b]) * 6 + i] = -1.0 - body = ib.get() - return body + with ib.for_range(0, num_anchors - valid_count[i]) as j: + with ib.if_scope(k < box_data_length): + out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 + box_indices[i * num_anchors + j + valid_count[i]] = -1 + # Only return max_output_size number of valid boxes + num_valid_boxes[0] = 0 + with ib.if_scope(max_output_size > 0): + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(out[base_idx + offset_j] >= 0): + with ib.if_scope(num_valid_boxes[0] == max_output_size): + with ib.if_scope(k < box_data_length): + out[base_idx + offset_j + k] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): + num_valid_boxes[0] += 1 + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) + + return ib.get() + + +def invalid_to_bottom_pre(data, flag, idx): + """Low level IR to rearrange nms output to move all valid entries to top. + + Parameters + ---------- + data: Buffer + 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. + + flag : Buffer + 1D Buffer of flag indicating valid data with [num_anchors]. + + idx : Buffer + 1D Buffer of valid data indices with [num_anchors]. + + Returns + ------- + stmt : Stmt + The result IR statement. 
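The invalid_to_bottom pair of kernels (the flag/prefix pass above and the scatter pass below) performs a stable compaction of the NMS output: rows whose class id (element 0) is non-negative move to the front, and everything else is filled with -1. A NumPy equivalent (sketch only; the helper name is assumed):

    import numpy as np

    def invalid_to_bottom_ref(out):
        result = np.full_like(out, -1.0)
        for i in range(out.shape[0]):
            keep = out[i, :, 0] >= 0            # valid boxes in this batch
            result[i, :keep.sum()] = out[i, keep]
        return result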
+ """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + + max_threads = int(math.sqrt( + tvm.target.current_target(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + j = bx * max_threads + tx + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * elem_length + with ib.if_scope(j < num_anchors): + with ib.if_scope(data[base_idx + j * elem_length] >= 0): + flag[i * num_anchors + j] = 1 + idx[i * num_anchors + j] = 1 + with ib.else_scope(): + flag[i * num_anchors + j] = 0 + idx[i * num_anchors + j] = 0 + + with ib.if_scope(j < batch_size): + with ib.for_range(0, num_anchors) as k: + with ib.if_scope(k > 0): + idx[j * num_anchors + k] += idx[j * num_anchors + k - 1] + return ib.get() + + +def invalid_to_bottom_ir(data, flag, idx, out): + """Low level IR to rearrange nms output to move all valid entries to top. + + Parameters + ---------- + data: Buffer + 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms. + + flag : Buffer + 1D Buffer of flag indicating valid data with [num_anchors]. + + idx : Buffer + 1D Buffer of valid data indices with [num_anchors]. + + out : Buffer + 3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length]. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + out = ib.buffer_ptr(out) + + max_threads = int(math.sqrt( + tvm.target.current_target(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + j = bx * max_threads + tx + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * elem_length + with ib.if_scope(j < num_anchors): + with ib.for_range(0, elem_length) as k: + out[base_idx + j * elem_length + k] = -1.0 + with ib.if_scope(flag[i * num_anchors + j] > 0): + with ib.for_range(0, elem_length) as k: + out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \ + = data[base_idx + j * elem_length + k] + return ib.get() @non_max_suppression.register(["cuda", "gpu"]) -def nms_gpu(data, - valid_count, - max_output_size=-1, - iou_threshold=0.5, - force_suppress=False, - top_k=-1, - id_index=0, - return_indices=True, - invalid_to_bottom=False): +def non_max_suppression_gpu(data, valid_count, max_output_size=-1, + iou_threshold=0.5, force_suppress=False, top_k=-1, + coord_start=2, score_index=1, id_index=0, + return_indices=True, invalid_to_bottom=False): """Non-maximum suppression operator for object detection. Parameters ---------- data : tvm.Tensor - 3-D tensor with shape [batch_size, num_anchors, 6]. + 3-D tensor with shape [batch_size, num_anchors, elem_length]. The last dimension should be in format of [class_id, score, box_left, box_top, box_right, box_bottom]. 
     valid_count : tvm.Tensor
         1-D tensor for valid number of boxes.
 
-    return_indices : boolean
-        Whether to return box indices in input data.
+    max_output_size : optional, int
+        Max number of output valid boxes for each instance.
+        By default all valid boxes are returned.
 
     iou_threshold : optional, float
         Non-maximum suppression threshold.
@@ -231,16 +676,25 @@ def nms_gpu(data,
     top_k : optional, int
         Keep maximum top k detections before nms, -1 for no limit.
 
+    coord_start : required, int
+        Start index of the consecutive 4 coordinates.
+
+    score_index : optional, int
+        Index of the scores/confidence of boxes.
+
     id_index : optional, int
         Index of the class categories, -1 to disable.
 
+    return_indices : boolean
+        Whether to return box indices in input data.
+
     invalid_to_bottom : optional, boolean
         Whether to move all valid bounding boxes to the top.
 
     Returns
     -------
     out : tvm.Tensor
-        3-D tensor with shape [batch_size, num_anchors, 6].
+        3-D tensor with shape [batch_size, num_anchors, elem_length].
 
     Example
     --------
@@ -253,12 +707,13 @@ def nms_gpu(data,
         iou_threshold = 0.7
         force_suppress = True
         top_k = -1
-        out = nms(data, valid_count, iou_threshold, force_suppress, topk)
+        out = non_max_suppression(data=data, valid_count=valid_count, iou_threshold=iou_threshold,
+                                  force_suppress=force_suppress, top_k=top_k, return_indices=False)
         np_data = np.random.uniform(dshape)
         np_valid_count = np.array([4])
         s = topi.generic.schedule_nms(out)
-        f = tvm.build(s, [data, valid_count, out], "llvm")
-        ctx = tvm.cpu()
+        f = tvm.build(s, [data, valid_count, out], "cuda")
+        ctx = tvm.gpu(0)
         tvm_data = tvm.nd.array(np_data, ctx)
         tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
         tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
@@ -266,38 +721,62 @@ def nms_gpu(data,
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
+
     valid_count_dtype = "int32"
     valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
                                       "valid_count_buf", data_alignment=4)
-    data_buf = api.decl_buffer(
-        data.shape, data.dtype, "data_buf", data_alignment=8)
+    score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = tvm.compute(
-        score_shape, lambda i, j: data[i, j, 1], name="score_tensor")
-    score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
-                                       "score_tensor_buf", data_alignment=8)
+    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
+    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
 
-    sort_tensor_dtype = "int32"
-    sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
+    sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                       "sort_tensor_buf", data_alignment=8)
-    sort_tensor = \
-        tvm.extern(score_shape,
-                   [score_tensor, valid_count],
-                   lambda ins, outs: sort_ir(
-                       ins[0], ins[1], outs[0]),
-                   dtype=sort_tensor_dtype,
-                   in_buffers=[score_tensor_buf, valid_count_buf],
-                   out_buffers=sort_tensor_buf,
-                   name="nms_sort")
 
-    out = \
-        tvm.extern(data.shape,
+    data_buf = api.decl_buffer(
+        data.shape, data.dtype, "data_buf", data_alignment=8)
+
+    out_buf = api.decl_buffer(
+        data.shape, data.dtype, "out_buf", data_alignment=8)
+
+    out, box_indices = \
+        tvm.extern([data.shape, score_shape],
                    [data, sort_tensor, valid_count],
                    lambda ins, outs: nms_ir(
-                       ins[0], ins[1], ins[2], outs[0], iou_threshold,
-                       force_suppress, top_k),
-                   dtype="float32",
+                       ins[0], ins[1], ins[2], outs[0], outs[1],
+                       max_output_size, iou_threshold, force_suppress,
+                       top_k, coord_start, id_index),
+                   dtype=[data.dtype, "int32"],
"int32"], in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], + name="nms", tag="nms") + + if return_indices: + return box_indices + + if invalid_to_bottom: + output_buf = api.decl_buffer( + data.shape, data.dtype, "output_buf", data_alignment=8) + temp_flag_buf = api.decl_buffer( + score_shape, valid_count_dtype, "temp_flag", data_alignment=8) + temp_idx_buf = api.decl_buffer( + score_shape, valid_count_dtype, "temp_idx", data_alignment=8) + temp_flag, temp_idx = tvm.extern([score_shape, score_shape], [out], + lambda ins, outs: invalid_to_bottom_pre( + ins[0], outs[0], outs[1]), + dtype=["int32", "int32"], + in_buffers=[out_buf], + out_buffers=[temp_flag_buf, temp_idx_buf], + name="invalid_to_bottom_phase_one") + + output = tvm.extern([data.shape], [out, temp_flag, temp_idx], + lambda ins, outs: invalid_to_bottom_ir( + ins[0], ins[1], ins[2], outs[0]), + dtype=[data.dtype], + in_buffers=[out_buf, temp_flag_buf, temp_idx_buf], + out_buffers=[output_buf], + name="invalid_to_bottom", + tag="invalid_to_bottom") + return output + return out diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py new file mode 100644 index 000000000000..99ba8527cdfb --- /dev/null +++ b/topi/python/topi/cuda/sort.py @@ -0,0 +1,249 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument +"""Argsort operator """ +import tvm + +from tvm import api +from topi.sort import argsort + +def sort_ir(data, output, axis, is_ascend): + """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + + Parameters + ---------- + data: Buffer + Buffer of input data. + + output : Buffer + Output buffer of indicies of sorted tensor with same shape as data. + + axis : Int + Axis long which to sort the input tensor. + + is_ascend : Boolean + Whether to sort in ascending or descending order. + + Returns + ------- + stmt : Stmt + The result IR statement. 
+ """ + size = 1 + axis_mul_before = 1 + axis_mul_after = 1 + shape = data.shape + if axis < 0: + axis = len(shape) + axis + for i, value in enumerate(shape, 0): + size *= value + if i < axis: + axis_mul_before *= value + elif i > axis: + axis_mul_after *= value + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + ib = tvm.ir_builder.create() + data = ib.buffer_ptr(data) + output = ib.buffer_ptr(output) + nthread_tx = max_threads + nthread_bx = size // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("vthread") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "virtual_thread", nthread_bx) + tid = bx * nthread_tx + tx + temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") + temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local") + is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) + + with ib.for_range(0, axis_mul_before) as i: + with ib.for_range(0, axis_mul_after) as j: + current_sort_num = shape[axis] + base_idx = i * shape[axis] * axis_mul_after + j + with ib.if_scope(tid < shape[axis]): + output[base_idx + tid * axis_mul_after] = tid.astype("float32") + # OddEvenTransposeSort + with ib.for_range(0, current_sort_num) as k: + with ib.if_scope(tid < (current_sort_num + 1) // 2): + offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tvm.all(is_ascend == 1, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] > data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + with ib.if_scope(tvm.all(is_ascend == 0, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] < data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) + + return ib.get() + + + +def sort_nms_ir(data, valid_count, output, axis, is_ascend): + """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + + Parameters + ---------- + data: Buffer + Buffer of input data. + + valid_count : Buffer + 1D Buffer of number of valid number of boxes. + + output : Buffer + Output buffer of indicies of sorted tensor with same shape as data. + + axis : Int + Axis long which to sort the input tensor. + + is_ascend : Boolean + Whether to sort in ascending or descending order. + + Returns + ------- + stmt : Stmt + The result IR statement. 
+ """ + + size = 1 + axis_mul_before = 1 + axis_mul_after = 1 + shape = data.shape + if axis < 0: + axis = len(shape) + axis + for i, value in enumerate(shape, 0): + size *= value + if i < axis: + axis_mul_before *= value + elif i > axis: + axis_mul_after *= value + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + ib = tvm.ir_builder.create() + data = ib.buffer_ptr(data) + valid_count = ib.buffer_ptr(valid_count) + output = ib.buffer_ptr(output) + nthread_tx = max_threads + nthread_bx = size // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("vthread") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "virtual_thread", nthread_bx) + tid = bx * nthread_tx + tx + temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") + temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) + + with ib.for_range(0, axis_mul_before) as i: + with ib.for_range(0, axis_mul_after) as j: + current_sort_num = valid_count[i * axis_mul_after + j] + base_idx = i * shape[axis] * axis_mul_after + j + with ib.if_scope(tid < shape[axis]): + output[base_idx + tid * axis_mul_after] = tid + # OddEvenTransposeSort + with ib.for_range(0, current_sort_num) as k: + with ib.if_scope(tid < (current_sort_num + 1) // 2): + offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tvm.all(is_ascend == 1, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] > data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + with ib.if_scope(tvm.all(is_ascend == 0, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] < data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) + + return ib.get() + +@argsort.register(["cuda", "gpu"]) +def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): + """Performs sorting along the given axis and returns an array of indicies + having same shape as an input array that index data in sorted order. + + Parameters + ---------- + data: tvm.Tensor + The input array. + + valid_count : tvm.Tensor + The number of valid elements to be sorted. + + axis : int + Axis long which to sort the input tensor. + + is_ascend : boolean + Whether to sort in ascending or descending order. + + flag : boolean + Whether this argsort is used in nms operator + + Returns + ------- + out : tvm.Tensor + The output of this function. 
+ """ + data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + if flag: + valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, + "valid_count_buf", data_alignment=4) + out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) + out = tvm.extern([data.shape], + [data, valid_count], + lambda ins, outs: sort_nms_ir( + ins[0], ins[1], outs[0], axis, is_ascend), + dtype="int32", + in_buffers=[data_buf, valid_count_buf], + out_buffers=[out_buf], + name="argsort_nms_gpu", + tag="argsort_nms_gpu") + else: + out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + out = tvm.extern([data.shape], + [data], + lambda ins, outs: sort_ir( + ins[0], outs[0], axis, is_ascend), + dtype=dtype, + in_buffers=[data_buf], + out_buffers=[out_buf], + name="argsort_gpu", + tag="argsort_gpu") + return out diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index 38b76f36801e..f7e5f94a5655 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -21,6 +21,7 @@ import tvm from tvm import api +from tvm.intrin import if_then_else, exp import topi @@ -93,12 +94,11 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): center_w = (j + offset_w) * steps_w for k in range(num_sizes + num_ratios - 1): - w = tvm.if_then_else(k < num_sizes, - size_ratio_concat[ - k] * in_height / in_width / 2.0, - size_ratio_concat[0] * in_height / in_width * - math.sqrt(size_ratio_concat[k + 1]) / 2.0) - h = tvm.if_then_else( + w = if_then_else(k < num_sizes, + size_ratio_concat[k] * in_height / in_width / 2.0, + size_ratio_concat[0] * in_height / in_width * + math.sqrt(size_ratio_concat[k + 1]) / 2.0) + h = if_then_else( k < num_sizes, size_ratio_concat[k] / 2.0, size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0) count = (i * in_width * (num_sizes + num_ratios - 1) + @@ -154,8 +154,7 @@ def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), out = topi.clip(out, 0, 1) return out - -def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold): +def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp_score, threshold): """Low level IR routing for transform location data preparation. Parameters @@ -166,13 +165,13 @@ def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, valid_count : Buffer Buffer of number of valid output boxes. 
- temp_flag : Buffer + temp_valid_count : Buffer Output intermediate result buffer - temp_id : Buffer + temp_cls_id : Buffer Output intermediate result buffer - temp_score_out : Buffer + temp_score : Buffer Output buffer threshold : float @@ -187,53 +186,53 @@ def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, num_classes = cls_prob.shape[1] num_anchors = cls_prob.shape[2] - max_threads = int( - tvm.target.current_target(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() - score = ib.buffer_ptr(temp_score_out) - cls_id = ib.buffer_ptr(temp_id) - flag = ib.buffer_ptr(temp_flag) + + cls_prob = ib.buffer_ptr(cls_prob) + cls_id = ib.buffer_ptr(temp_cls_id) + valid_count = ib.buffer_ptr(valid_count) + temp_valid_count = ib.buffer_ptr(temp_valid_count) + score = ib.buffer_ptr(temp_score) + + threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold) + + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = (batch_size * num_anchors) // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") - nthread_tx = max_threads - nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1 ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx - p_cls_prob = ib.buffer_ptr(cls_prob) - p_valid_count = ib.buffer_ptr(valid_count) with ib.if_scope(tid < batch_size * num_anchors): - n = tid / num_anchors # number of batches - i = tid % num_anchors # number of anchors - score[i] = -1.0 - cls_id[i] = 0 - p_valid_count[n] = 0 - with ib.for_range(0, num_classes-1, name="k") as k: - temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i] - with ib.if_scope(temp > score[i]): - cls_id[i] = k + 1 - score[i] = temp - with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)): - cls_id[i] = 0 - with ib.if_scope(cls_id[i] > 0): - flag[i] = 1 + i = tid / num_anchors + j = tid % num_anchors + valid_count[i] = 0 + score[tid] = -1.0 + cls_id[tid] = 0 + with ib.for_range(0, num_classes - 1) as k: + temp = cls_prob[i * num_classes * num_anchors + (k + 1) * num_anchors + j] + cls_id[tid] = if_then_else(temp > score[tid], k + 1, cls_id[tid]) + score[tid] = tvm.max(temp, score[tid]) + with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)): + cls_id[tid] = 0 + with ib.if_scope(cls_id[tid] > 0): + temp_valid_count[tid] = 1 with ib.else_scope(): - flag[i] = 0 + temp_valid_count[tid] = 0 with ib.if_scope(tid < batch_size): - with ib.for_range(0, num_anchors, name="k") as k: + with ib.for_range(0, num_anchors) as k: with ib.if_scope(k > 0): - flag[tid * num_anchors + - k] += flag[tid * num_anchors + k - 1] - p_valid_count[n] = flag[tid * num_anchors + num_anchors - 1] + temp_valid_count[tid * num_anchors + k] += \ + temp_valid_count[tid * num_anchors + k - 1] + valid_count[i] = temp_valid_count[tid * num_anchors + num_anchors - 1] - body = ib.get() - return body + return ib.get() - -def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ - out, clip, variances, batch_size, num_classes, num_anchors): +def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score, out, \ + clip, variances, batch_size, num_anchors): """Low level IR routing for transform location in multibox_detection operator. 
Parameters @@ -244,13 +243,13 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ anchor : Buffer Buffer of prior anchor boxes. - temp_flag : Buffer + temp_valid_count : Buffer Intermediate result buffer. - temp_id : Buffer + temp_cls_id : Buffer Intermediate result buffer. - temp_score_in : Buffer + temp_score : Buffer Input buffer which stores intermediate results. out : Buffer @@ -265,9 +264,6 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ batch_size : int Batch size - num_classes : int - Number of classes - num_anchors : int Number of anchors @@ -293,47 +289,55 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, ph = loc[loc_base_idx + 3] ox = px * vx * aw + ax oy = py * vy * ah + ay - ow = tvm.exp(pw * vw) * aw / 2.0 - oh = tvm.exp(ph * vh) * ah / 2.0 - return tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \ - tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \ - tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \ - tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh) - - max_threads = int( - tvm.target.current_target(allow_none=False).max_num_threads) + ow = exp(pw * vw) * aw / 2.0 + oh = exp(ph * vh) * ah / 2.0 + return tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, ox - ow)), ox - ow), \ + tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, oy - oh)), oy - oh), \ + tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, ox + ow)), ox + ow), \ + tvm.if_then_else(clip, tvm.max(0.0, tvm.min(1.0, oy + oh)), oy + oh) + ib = tvm.ir_builder.create() - score = ib.buffer_ptr(temp_score_in) - cls_id = ib.buffer_ptr(temp_id) - flag = ib.buffer_ptr(temp_flag) + + loc_pred = ib.buffer_ptr(loc_pred) + anchor = ib.buffer_ptr(anchor) + temp_valid_count = ib.buffer_ptr(temp_valid_count) + cls_id = ib.buffer_ptr(temp_cls_id) + score = ib.buffer_ptr(temp_score) + out_loc = ib.buffer_ptr(out) + + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = (batch_size * num_anchors) // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") - nthread_tx = max_threads - nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1 ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx - p_loc_pred = ib.buffer_ptr(loc_pred) - p_anchor = ib.buffer_ptr(anchor) - p_out = ib.buffer_ptr(out) with ib.if_scope(tid < batch_size * num_anchors): - n = tid / num_anchors # number of batches - i = tid % num_anchors # number of anchors + i = tid / num_anchors + j = tid % num_anchors with ib.if_scope(cls_id[tid] > 0): with ib.if_scope(tid == 0): - out_base_idx = n * num_anchors * 6 + out_base_idx = i * num_anchors * 6 + out_loc[out_base_idx] = cls_id[tid] - 1.0 + out_loc[out_base_idx + 1] = score[tid] + out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \ + out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4, + anchor, j * 4, clip, variances[0], + variances[1], variances[2], + variances[3]) with ib.else_scope(): - out_base_idx = n * num_anchors * 6 + flag[tid - 1] * 6 - p_out[out_base_idx] = cls_id[tid] - 1.0 - p_out[out_base_idx + 1] = score[tid] - p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \ - p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, - p_anchor, i*4, clip, 
variances[0], - variances[1], variances[2], variances[3]) + out_base_idx = i * num_anchors * 6 + temp_valid_count[tid - 1] * 6 + out_loc[out_base_idx] = cls_id[tid] - 1.0 + out_loc[out_base_idx + 1] = score[tid] + out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \ + out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4, + anchor, j * 4, clip, variances[0], + variances[1], variances[2], + variances[3]) - body = ib.get() - return body + return ib.get() @multibox_transform_loc.register(["cuda", "gpu"]) @@ -372,44 +376,48 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \ 1-D tensor with shape (batch_size,), number of valid anchor boxes. """ batch_size = cls_prob.shape[0] - num_classes = cls_prob.shape[1] num_anchors = cls_prob.shape[2] oshape = (batch_size, num_anchors, 6) # Define data alignment for intermediate buffer valid_count_dtype = "int32" + out_loc_dtype = loc_pred.dtype + valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype, "valid_count_buf", data_alignment=4) - out_buf = api.decl_buffer( - oshape, cls_prob.dtype, "out_buf", data_alignment=8) - size = num_anchors - temp_flag_buf = api.decl_buffer( - (size,), valid_count_dtype, "flag", data_alignment=8) - temp_id_buf = api.decl_buffer( - (size,), valid_count_dtype, "cls_id", data_alignment=8) + loc_pred_buf = api.decl_buffer(loc_pred.shape, loc_pred.dtype, + "loc_pred_buf", data_alignment=8) + anchor_buf = api.decl_buffer(anchor.shape, anchor.dtype, + "anchor_buf", data_alignment=8) + + temp_valid_count_buf = api.decl_buffer( + (batch_size, num_anchors,), valid_count_dtype, "temp_valid_count", data_alignment=8) + temp_cls_id_buf = api.decl_buffer( + (batch_size, num_anchors,), valid_count_dtype, "temp_cls_id", data_alignment=8) temp_score_buf = api.decl_buffer( - (size,), cls_prob.dtype, "score", data_alignment=8) + (batch_size, num_anchors,), cls_prob.dtype, "temp_score", data_alignment=8) - valid_count, temp_flag, temp_id, temp_score = \ - tvm.extern([(batch_size,), (size,), (size,), (size,)], - [cls_prob], + valid_count, temp_valid_count, temp_cls_id, temp_score = \ + tvm.extern([(batch_size,), (batch_size, num_anchors,), (batch_size, num_anchors,), \ + (batch_size, num_anchors,)], [cls_prob], lambda ins, outs: transform_loc_pre( ins[0], outs[0], outs[1], outs[2], outs[3], threshold), - dtype=[valid_count_dtype, - valid_count_dtype, valid_count_dtype, cls_prob.dtype], - out_buffers=[valid_count_buf, - temp_flag_buf, temp_id_buf, temp_score_buf], - tag="multibox_transform_loc_first_step") + dtype=[valid_count_dtype, valid_count_dtype, valid_count_dtype, cls_prob.dtype], + out_buffers=[valid_count_buf, temp_valid_count_buf, \ + temp_cls_id_buf, temp_score_buf], + tag="multibox_transform_loc_phase_one") - out = \ + out_loc = \ tvm.extern([oshape], - [loc_pred, anchor, temp_flag, temp_id, temp_score], + [loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score], lambda ins, outs: transform_loc_ir( - ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \ - variances, batch_size, num_classes, num_anchors), - dtype=[cls_prob.dtype], - out_buffers=[out_buf], + ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, variances, \ + batch_size, num_anchors), + in_buffers=[loc_pred_buf, anchor_buf, temp_valid_count_buf, \ + temp_cls_id_buf, temp_score_buf], + dtype=[out_loc_dtype], tag="multibox_transform_loc") - return [out, valid_count] + + return [out_loc, valid_count] @multibox_detection.register(["cuda", "gpu"]) @@ -453,6 +461,7 @@ def 
multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = non_max_suppression( - inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) + out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1, + iou_threshold=nms_threshold, force_suppress=force_suppress, + top_k=nms_topk, return_indices=False) return out diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 5d7bc9e00da6..78f5c1f51ec6 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -32,11 +32,15 @@ def _default_schedule(outs): def traverse(op): """inline all one-to-one-mapping operators except the last stage (output)""" - if "nms" in op.tag: - sort = op.input_tensors[1] + if op.tag in ["nms", "invalid_to_bottom"]: + if op.tag == "nms": + sort = op.input_tensors[1] + else: + out = op.input_tensors[0] + sort = s[out].op.input_tensors[1] score = s[sort].op.input_tensors[0] fused = s[score].fuse(*s[score].op.axis) - num_thread = tvm.target.current_target(allow_none=False).max_num_threads + num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads) bx, tx = s[score].split(fused, factor=num_thread) s[score].bind(bx, tvm.thread_axis("blockIdx.x")) s[score].bind(tx, tvm.thread_axis("threadIdx.x")) @@ -199,3 +203,30 @@ def schedule_get_valid_counts(outs): The computation schedule for the op. """ return _default_schedule(outs) + +@generic.schedule_argsort.register(["cuda", "gpu"]) +def schedule_argsort(outs): + """Schedule for argsort operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argsort + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + from .injective import _schedule_injective + def traverse(op): + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + traverse(outs[0].op) + return s diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 8450e2d4c4e2..6bf5f3a053c9 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -19,3 +19,4 @@ from .injective import * from .extern import * from .vision import * +from .sort import * diff --git a/topi/python/topi/generic/sort.py b/topi/python/topi/generic/sort.py new file mode 100644 index 000000000000..1ad088c50d04 --- /dev/null +++ b/topi/python/topi/generic/sort.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member
+"""Generic sort operators"""
+from __future__ import absolute_import as _abs
+import tvm
+from .vision import _default_schedule

+@tvm.target.generic_func
+def schedule_argsort(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argsort
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py
new file mode 100644
index 000000000000..84fff8d8f0cd
--- /dev/null
+++ b/topi/python/topi/sort.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments
+"""Argsort operator"""
+import tvm
+from tvm import api

+@tvm.target.generic_func
+def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
+    """Performs sorting along the given axis and returns an array
+    of indices having the same shape as an input array that index
+    data in sorted order.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tensor.
+
+    valid_count : tvm.Tensor
+        1-D tensor holding the number of valid boxes per batch;
+        used only when flag is True (SSD NMS).
+
+    axis : optional, int
+        Axis along which to sort the input tensor.
+        By default the last axis is used.
+
+    is_ascend : optional, boolean
+        Whether to sort in ascending or descending order.
+
+    dtype : optional, string
+        DType of the output indices.
+
+    flag : optional, boolean
+        Whether to sort only the first valid_count entries of each
+        batch (the NMS path) instead of the whole axis.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Sorted index tensor.
+
+    Example
+    --------
+    .. code-block:: python
+
+        # An example to use argsort
+        import numpy as np
+        import topi
+
+        dshape = (1, 5, 6)
+        data = tvm.placeholder(dshape, name="data")
+        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
+        axis = 0
+        is_ascend = False
+        flag = False
+        # pass keyword arguments so that flag does not bind to dtype
+        out = argsort(data, valid_count, axis=axis, is_ascend=is_ascend, flag=flag)
+        np_data = np.random.uniform(size=dshape).astype(data.dtype)
+        np_valid_count = np.array([4]).astype("int32")
+        s = topi.generic.schedule_argsort(out)
+        f = tvm.build(s, [data, valid_count, out], "llvm")
+        ctx = tvm.cpu()
+        tvm_data = tvm.nd.array(np_data, ctx)
+        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
+        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
+        f(tvm_data, tvm_valid_count, tvm_out)
+    """
+    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    if flag:
+        valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
+                                          "valid_count_buf", data_alignment=4)
+        out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
+        out = \
+            tvm.extern(data.shape,
+                       [data, valid_count],
+                       lambda ins, outs: tvm.call_packed(
+                           "tvm.contrib.sort.argsort_nms", ins[0], ins[1],
+                           outs[0], axis, is_ascend),
+                       dtype="int32",
+                       in_buffers=[data_buf, valid_count_buf],
+                       out_buffers=out_buf,
+                       name="argsort_nms_cpu",
+                       tag="argsort_nms_cpu")
+    else:
+        out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
+        out = \
+            tvm.extern(data.shape,
+                       [data],
+                       lambda ins, outs: tvm.call_packed(
+                           "tvm.contrib.sort.argsort", ins[0],
+                           outs[0], axis, is_ascend),
+                       dtype=dtype,
+                       in_buffers=[data_buf],
+                       out_buffers=out_buf,
+                       name="argsort_cpu",
+                       tag="argsort_cpu")
+    return out
diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py
index d8b15aac42c6..979565d31662 100644
--- a/topi/python/topi/vision/nms.py
+++ b/topi/python/topi/vision/nms.py
@@ -18,7 +18,8 @@
 """Non-maximum suppression operator"""
 import tvm
-from tvm import api, hybrid
+from tvm import hybrid
+from ..sort import argsort

 @hybrid.script
 def hybrid_rearrange_out(data):
@@ -129,7 +130,7 @@ def get_valid_counts(data, score_threshold=0):
 @hybrid.script
 def hybrid_nms(data, sorted_index, valid_count,
                max_output_size, iou_threshold, force_suppress,
-               top_k, id_index):
+               top_k, coord_start, id_index):
     """Hybrid routine for non-maximum suppression.

     Parameters
@@ -158,6 +159,9 @@ def hybrid_nms(data, sorted_index, valid_count,
     top_k : tvm.const
         Keep maximum top k detections before nms, -1 for no limit.

+    coord_start : tvm.const
+        Start index of the consecutive 4 coordinates.
+
     id_index : tvm.const
         index of the class categories, -1 to disable.
@@ -208,7 +212,7 @@ def hybrid_nms(data, sorted_index, valid_count,
                         batch_idx = i
                         box_a_idx = j
                         box_b_idx = k
-                        box_start_idx = 2
+                        box_start_idx = coord_start
                         a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
                         a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
                         a_l = output[batch_idx, box_a_idx, box_start_idx]
@@ -252,7 +256,8 @@ def hybrid_nms(data, sorted_index, valid_count,
 @tvm.target.generic_func
 def non_max_suppression(data, valid_count, max_output_size=-1,
                         iou_threshold=0.5, force_suppress=False, top_k=-1,
-                        id_index=0, return_indices=True, invalid_to_bottom=False):
+                        coord_start=2, score_index=1, id_index=0,
+                        return_indices=True, invalid_to_bottom=False):
     """Non-maximum suppression operator for object detection.

     Parameters
@@ -278,6 +283,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
     top_k : optional, int
         Keep maximum top k detections before nms, -1 for no limit.
+    coord_start : optional, int
+        Start index of the consecutive 4 coordinates.
+
+    score_index : optional, int
+        Index of the scores/confidence of boxes.
+
     id_index : optional, int
         index of the class categories, -1 to disable.
@@ -317,32 +328,16 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    valid_count_dtype = "int32"
-    valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
-                                      "valid_count_buf", data_alignment=4)
-    score_axis = 1
+    score_axis = score_index
     score_shape = (batch_size, num_anchors)
     score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
-    score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
-                                       "score_tensor_buf", data_alignment=8)
-    sort_tensor_dtype = "int32"
-    sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
-                                      "sort_tensor_buf", data_alignment=8)
-    sort_tensor = \
-        tvm.extern(score_shape,
-                   [score_tensor, valid_count],
-                   lambda ins, outs: tvm.call_packed(
-                       "tvm.contrib.sort.argsort", ins[0], ins[1],
-                       outs[0], score_axis, True),
-                   dtype=sort_tensor_dtype,
-                   in_buffers=[score_tensor_buf, valid_count_buf],
-                   out_buffers=sort_tensor_buf,
-                   name="nms_sort")
+    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
     out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
                                   tvm.const(max_output_size, dtype="int32"),
                                   tvm.const(iou_threshold, dtype="float32"),
                                   tvm.const(force_suppress, dtype="bool"),
                                   tvm.const(top_k, dtype="int32"),
+                                  tvm.const(coord_start, dtype="int32"),
                                   tvm.const(id_index, dtype="int32"))
     if not return_indices and invalid_to_bottom:
         out = hybrid_rearrange_out(out)
diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py
index 799669003753..ca1b4a9eb268 100644
--- a/topi/python/topi/vision/ssd/multibox.py
+++ b/topi/python/topi/vision/ssd/multibox.py
@@ -308,7 +308,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
     """
     inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip,
                                        threshold, variances)
-    out = non_max_suppression(inter_out[0], inter_out[1], -1,
-                              nms_threshold, force_suppress, nms_topk,
-                              return_indices=False)
+    out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1,
+                              iou_threshold=nms_threshold, force_suppress=force_suppress,
+                              top_k=nms_topk, return_indices=False)
     return out
diff --git a/topi/tests/python/test_topi_sort.py b/topi/tests/python/test_topi_sort.py
new file mode 100644
index 000000000000..3a2c9c2e4980
--- /dev/null
+++ b/topi/tests/python/test_topi_sort.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for vision package""" +from __future__ import print_function +import math +import numpy as np +import tvm +import topi +import topi.testing + +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple +from topi import argsort + +def test_argsort(): + dshape = (1, 8) + valid_count_shape = (2,) + data = tvm.placeholder(dshape, name="data", dtype="float32") + valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") + np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype) + np_valid_count = np.array([4]).astype(valid_count.dtype) + np_result = np.argsort(-np_data) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False) + s = topi.generic.schedule_argsort(out) + + tvm_data = tvm.nd.array(np_data, ctx) + tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx) + f = tvm.build(s, [data, valid_count, out], device) + f(tvm_data, tvm_valid_count, tvm_out) + tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0) + + for device in ['llvm', 'cuda', 'opencl']: + check_device(device) + + +if __name__ == "__main__": + test_argsort() diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 6bb57b541c88..483f3a641c70 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -66,7 +66,7 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) - for device in ['llvm']: + for device in ['llvm', 'cuda', 'opencl']: check_device(device) @@ -124,7 +124,7 @@ def check_device(device): f(tvm_data, tvm_valid_count, tvm_indices_out) tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) - for device in ['llvm']: + for device in ['llvm', 'cuda', 'opencl']: check_device(device) @@ -231,7 +231,7 @@ def check_device(device): f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4) - for device in ['llvm', 'opencl']: + for device in ['llvm', 'opencl', 'cuda']: check_device(device) @@ -275,7 +275,7 @@ def check_device(device): f(tvm_a, tvm_rois, tvm_b) tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3) - for device in ['llvm', 'cuda']: + for device in ['llvm', 'cuda', 'opencl']: check_device(device) diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py index fe84283ad191..ff7691c7bf55 100644 --- a/tutorials/frontend/deploy_ssd_gluoncv.py +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -18,6 +18,7 @@ Deploy Single Shot Multibox Detector(SSD) model =============================================== **Author**: `Yao Wang `_ +`Leyuan Wang `_ This article is an introductory tutorial to deploy SSD models with TVM. We will use GluonCV pre-trained SSD model and convert it to Relay IR @@ -37,30 +38,29 @@ # ------------------------------ # .. note:: # -# Currently we support compiling SSD on CPU only. -# GPU support is in progress. +# We support compiling SSD on bot CPUs and GPUs now. 
#
#   To get best inference performance on CPU, change
#   target argument according to your device and
#   follow the :ref:`tune_relay_x86` to tune x86 CPU and
#   :ref:`tune_relay_arm` for arm cpu.
#
+#   To get best performance for SSD on Intel graphics,
+#   change target argument to 'opencl -device=intel_graphics'
+#
#   SSD with VGG as body network is not supported yet since
#   x86 conv2d schedule doesn't support dilation.
 supported_model = [
-    'ssd_512_resnet18_v1_voc',
-    'ssd_512_resnet18_v1_coco',
     'ssd_512_resnet50_v1_voc',
     'ssd_512_resnet50_v1_coco',
     'ssd_512_resnet101_v2_voc',
-    'ssd_512_mobilenet1_0_voc',
-    'ssd_512_mobilenet1_0_coco',
+    'ssd_512_mobilenet1.0_voc',
+    'ssd_512_mobilenet1.0_coco',
 ]

-model_name = "ssd_512_resnet50_v1_voc"
+model_name = supported_model[0]
 dshape = (1, 3, 512, 512)
-dtype = "float32"
 target_list = ctx_list()

######################################################################
@@ -76,7 +76,7 @@
 block = model_zoo.get_model(model_name, pretrained=True)

-def compile(target):
+def build(target):
     net, params = relay.frontend.from_mxnet(block, {"data": dshape})
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(net, target, params=params)
@@ -98,10 +98,7 @@ def run(graph, lib, params, ctx):
     return class_IDs, scores, bounding_boxs

 for target, ctx in target_list:
-    if target == "cuda":
-        print("GPU not supported yet, skip.")
-        continue
-    graph, lib, params = compile(target)
+    graph, lib, params = build(target)
     class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)

######################################################################
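For reference, the box decoding that ``transform_loc`` performs inside the IR above is plain arithmetic: the regression output ``(px, py, pw, ph)`` shifts and scales an anchor, and the result is optionally clipped to ``[0, 1]``. Below is a minimal NumPy sketch of the same math; the corner-to-center conversion for the anchor is the usual SSD convention and is assumed here, since only the tail of ``transform_loc`` appears in the hunk.

.. code-block:: python

    import numpy as np

    def decode_box(loc, anchor, variances=(0.1, 0.1, 0.2, 0.2), clip=True):
        """Decode one SSD box; mirrors the transform_loc IR above."""
        al, at, ar, ab = anchor                    # anchor corners (assumed layout)
        aw, ah = ar - al, ab - at                  # anchor width/height
        ax, ay = (al + ar) / 2.0, (at + ab) / 2.0  # anchor center
        px, py, pw, ph = loc                       # regression output
        vx, vy, vw, vh = variances
        ox = px * vx * aw + ax                     # decoded center
        oy = py * vy * ah + ay
        ow = np.exp(pw * vw) * aw / 2.0            # decoded half extents
        oh = np.exp(ph * vh) * ah / 2.0
        box = np.array([ox - ow, oy - oh, ox + ow, oy + oh])
        return np.clip(box, 0.0, 1.0) if clip else box

    print(decode_box((0.1, 0.2, 0.05, -0.05), (0.2, 0.2, 0.6, 0.6)))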
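The ``flag`` argument of ``topi.argsort`` selects between the two packed functions the operator lowers to: ``tvm.contrib.sort.argsort`` for a plain sort, and ``tvm.contrib.sort.argsort_nms`` when only the first ``valid_count[i]`` boxes of each batch should participate, which is how ``non_max_suppression`` now calls it. A sketch of the two call shapes follows (graph construction only; assumes this TVM revision, where the test above imports ``argsort`` from ``topi``):

.. code-block:: python

    import tvm
    import topi

    scores = tvm.placeholder((1, 8), name="scores", dtype="float32")
    valid_count = tvm.placeholder((1,), name="valid_count", dtype="int32")

    with tvm.target.create("llvm"):
        # NMS path: descending sort of the first valid_count[i] scores per batch
        nms_order = topi.argsort(scores, valid_count, axis=1, is_ascend=False, flag=True)
        # plain path: full argsort along the last axis; valid_count is ignored
        plain = topi.argsort(scores, valid_count, axis=-1, is_ascend=True, flag=False)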
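The new ``coord_start``, ``score_index`` and ``id_index`` keywords describe where each field lives along the last axis of the detection tensor, so layouts other than the default ``(class_id, score, x1, y1, x2, y2)`` emitted by ``multibox_transform_loc`` can be suppressed as well. With the defaults, the fields decompose as in this illustrative snippet (the values are hypothetical):

.. code-block:: python

    import numpy as np

    # one detection row in the default layout
    det = np.array([7.0, 0.92, 0.10, 0.20, 0.60, 0.80], dtype="float32")

    id_index, score_index, coord_start = 0, 1, 2  # the new keyword defaults
    class_id = det[id_index]                      # 7.0
    score = det[score_index]                      # 0.92
    left, top, right, bottom = det[coord_start:coord_start + 4]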
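Finally, a minimal end-to-end sketch of the operator through Relay, assuming the ``tvm.relay.argsort`` wrapper exposes the ``axis``, ``is_ascend`` and ``dtype`` attributes and using the graph executor API of this TVM revision (note the indices come back as ``float32``, the wrapper's default ``dtype``):

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 5, 6), dtype="float32")
    func = relay.Function([x], relay.argsort(x, axis=-1, is_ascend=True))

    x_np = np.random.uniform(size=(1, 5, 6)).astype("float32")
    intrp = relay.create_executor("graph", ctx=tvm.cpu(), target="llvm")
    out = intrp.evaluate(func)(x_np)

    # distinct random floats have no ties, so the indices should match
    # numpy's reference argsort exactly
    np.testing.assert_allclose(out.asnumpy(),
                               np.argsort(x_np, axis=-1).astype("float32"))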