diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index ed6e8f086a0d..82625ffac557 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -54,6 +54,68 @@ def atomic_add(x, y): return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y) +def rearrange_indices_out_ir(data, out, valid_box_count): + """Hybrid routine to rearrange nms output to + move all valid entries to top. + + Parameters + ---------- + data : tvm.te.Tensor or numpy NDArray + tensor with shape [batch_size, num_anchors]. + + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + + ib = tvm.tir.ir_builder.create() + data = ib.buffer_ptr(data) + out = ib.buffer_ptr(out) + valid_box_count = ib.buffer_ptr(valid_box_count) + + one_count = tvm.tir.const(1, dtype="int32") + atomic_add_return = ib.allocate( + valid_box_count.dtype, (batch_size,), name="atomic_add_return", scope="local" + ) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + len_inner_for = (batch_size * num_anchors) // nthread_tx + 2 + + idxd = tvm.tir.indexdiv + idxm = tvm.tir.indexmod + + with ib.for_range(0, len_inner_for, name="i") as i: + idx = tx * len_inner_for + i + batch_idx = idxd(idx, num_anchors) + with ib.if_scope(idx < batch_size): + valid_box_count[idx] = 0 + with ib.if_scope(idx < batch_size * num_anchors): + with ib.if_scope(data[idx] >= 0): + atomic_add_return[batch_idx] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]), + one_count, + ) + out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = data[idx] + with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] < -num_anchors)): + atomic_add_return[batch_idx] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]), + one_count, + ) + out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0 + + with ib.if_scope(idxm(idx, num_anchors) >= valid_box_count[batch_idx]): + out[idx] = -1 + + return ib.get() + + def get_valid_counts_ir( data, valid_count, out, out_indices, score_threshold, id_index, score_index ): @@ -198,6 +260,7 @@ def nms_ir( data, sorted_index, valid_count, + indices, out, box_indices, max_output_size, @@ -207,6 +270,7 @@ def nms_ir( coord_start, id_index, score_index, + return_indices, ): """Low level IR routing for transform location in multibox_detection operator. @@ -285,6 +349,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) + indices = ib.buffer_ptr(indices) num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) @@ -379,6 +444,12 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.else_scope(): num_valid_boxes[0] += 1 + if return_indices: + with ib.if_scope(j < valid_count[i]): + box_idx = box_indices[i * num_anchors + j] + with ib.if_scope(box_idx >= 0): + box_indices[i * num_anchors + j] = indices[i * num_anchors + box_idx] + return ib.get() @@ -502,14 +573,16 @@ def non_max_suppression( ) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) out, box_indices = te.extern( [data.shape, score_shape], - [data, sort_tensor, valid_count], + [data, sort_tensor, valid_count, indices], lambda ins, outs: nms_ir( ins[0], ins[1], ins[2], + ins[3], outs[0], outs[1], max_output_size, @@ -519,14 +592,26 @@ def non_max_suppression( coord_start, id_index, score_index, + return_indices, ), dtype=[data.dtype, "int32"], - in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], + in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf], name="nms", tag="nms", ) - # TODO(yongwww): Update cuda nms to be consistent with cpu version + if return_indices: - return box_indices + out_buf = tvm.tir.decl_buffer( + box_indices.shape, box_indices.dtype, "out_buf", data_alignment=8 + ) + return te.extern( + [box_indices.shape, (batch_size, 1)], + [box_indices], + lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]), + dtype=[box_indices.dtype, valid_count.dtype], + in_buffers=[out_buf], + name="rearrange_indices_out", + tag="rearrange_indices_out", + ) return out diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 76e1808698e5..b076fde9ac6e 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -52,15 +52,16 @@ def hybrid_rearrange_box_out(data, one, batch_size, num_anchors): """ elem_length = data.shape[2] output = output_tensor((batch_size, num_anchors, elem_length), data.dtype) + valid_indices = allocate((batch_size,), "int32") for i in parallel(batch_size): - valid_idx = 0 + valid_indices[i] = 0 for j in range(num_anchors): if data[i, j, 0] >= 0: for k in range(elem_length): - output[i, valid_idx, k] = data[i, j, k] - valid_idx += 1 - if j >= valid_idx: + output[i, valid_indices[i], k] = data[i, j, k] + valid_indices[i] += 1 + if j >= valid_indices[i]: for k in range(elem_length): output[i, j, k] = -one return output @@ -100,19 +101,20 @@ def hybrid_rearrange_indices_out(data, one, batch_size, num_anchors): """ valid_box_count = output_tensor((batch_size, 1), "int32") output = output_tensor((batch_size, num_anchors), data.dtype) + valid_indices = allocate((batch_size,), "int32") for i in parallel(batch_size): - valid_idx = 0 + valid_indices[i] = 0 for j in range(num_anchors): if data[i, j] >= 0: - output[i, valid_idx] = data[i, j] - valid_idx += 1 + output[i, valid_indices[i]] = data[i, j] + valid_indices[i] += 1 if data[i, j] > num_anchors or data[i, j] < -num_anchors: - output[i, valid_idx] = 0 - valid_idx += 1 - if j >= valid_idx: + output[i, valid_indices[i]] = 0 + valid_indices[i] += 1 + if j >= valid_indices[i]: output[i, j] = -one - valid_box_count[i, 0] = valid_idx + valid_box_count[i, 0] = valid_indices[i] return output, valid_box_count diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 546973704fea..3fd8837a885e 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -25,6 +25,8 @@ from utils.assert_diagnostic import DiagnosticTesting import tvm.topi.testing +import os + def int32(val): return relay.const(val, "int32") @@ -38,27 +40,43 @@ def any_dims(ndim): def check_result( - args, mod, expected, flatten=False, assert_shape=False, only_vm=False, targets=None + args, + mod, + expected, + flatten=False, + assert_shape=False, + only_vm=False, + targets=None, + disable_targets=None, ): + if not isinstance(expected, list): + expected = [expected] for kind in ["debug", "vm"]: targets = targets or tvm.testing.enabled_targets() for tgt, ctx in targets: + if disable_targets and tgt in disable_targets: + continue if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type): continue ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt) result = ex.evaluate()(*args) - result = result.asnumpy() - if assert_shape: - assert result.shape == expected, "Shape mismatch: expect %s but got %s." % ( - str(expected), - str(result.shape), - ) - return + if isinstance(result, tvm.runtime.container.ADT): + result = [r.asnumpy() for r in result] + else: + result = [result.asnumpy()] - if flatten: - result = result.flatten() - expected = expected.flatten() - tvm.testing.assert_allclose(result, expected, atol=2e-6) + for r, e in zip(result, expected): + if assert_shape: + assert r.shape == e, "Shape mismatch: expect %s but got %s." % ( + str(e), + str(r), + ) + return + + if flatten: + r = r.flatten() + e = e.flatten() + tvm.testing.assert_allclose(r, e, atol=2e-6) def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): @@ -1370,5 +1388,54 @@ def test_any_where(): ) +# TODO(kevinthesun): enable gpu test when Thrust is available in ci. +# @tvm.testing.uses_gpu +def test_non_max_suppression(): + x0 = relay.var("x0", relay.ty.TensorType((1, relay.Any(), 6), "float32")) + x1 = relay.var("x1", relay.ty.TensorType((1,), "int32")) + x2 = relay.var("x2", relay.ty.TensorType((1, relay.Any()), "int32")) + x3 = relay.var("x3", relay.ty.TensorType((), "int32")) + z = relay.vision.non_max_suppression( + x0, + x1, + x2, + x3, + iou_threshold=0.5, + force_suppress=True, + top_k=2, + return_indices=True, + invalid_to_bottom=False, + ) + z = z.astuple() + func = relay.Function([x0, x1, x2, x3], z) + mod = tvm.IRModule() + mod["main"] = func + + np_data = np.array( + [ + [ + [0, 0.8, 1, 20, 25, 45], + [1, 0.7, 30, 60, 50, 80], + [0, 0.4, 4, 21, 19, 40], + [2, 0.9, 35, 61, 52, 79], + [1, 0.5, 100, 60, 70, 110], + ] + ] + ).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32") + np_max_output_size = -1 + np_indices_result = np.array([[4, 0, -1, -1, -1]]) + np_valid_box_count = np.array([[2]]).astype("int32") + + check_result( + [np_data, np_valid_count, np_indices, np_max_output_size], + mod, + [np_indices_result, np_valid_box_count], + only_vm=False, + disable_targets=["nvptx"], + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 5a5a12c9efe0..9e9aaf842669 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -393,8 +393,8 @@ def verify_nms( intrp2 = relay.create_executor("debug", ctx=ctx, target=target) op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) - if target == "cuda": - return + if target == "nvptx": + continue op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5) op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 22c9045fd457..6d6353eebce6 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -202,7 +202,7 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx) - if device == "llvm": + if device in ["llvm", "cuda"]: f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device) f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) else: