Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
84 changes: 83 additions & 1 deletion python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,88 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
return [valid_count, out, out_indices]


def rearrange_indices_out_ir(data, output, valid_box_count):
    """Build the IR that packs the non-negative entries of each row to the front.

    Parameters
    ----------
    data : Buffer
        Input data. 2-D Buffer with shape [batch_size, num_anchors].

    output : Buffer
        2-D Buffer with shape [batch_size, num_anchors]. Every slot is first
        set to -1, then the non-negative entries of the matching ``data`` row
        are written from the start of the row in their original order.

    valid_box_count : Buffer
        2-D Buffer with shape [batch_size, 1]. Receives, per row, how many
        non-negative entries were copied into ``output``.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    rows = data.shape[0]
    cols = data.shape[1]
    builder = tvm.tir.ir_builder.create()

    src = builder.buffer_ptr(data)
    dst = builder.buffer_ptr(output)
    count = builder.buffer_ptr(valid_box_count)

    # Launch configuration: threadIdx.x spans the batch, blockIdx.x has extent 1.
    thread_x = te.thread_axis("threadIdx.x")
    block_x = te.thread_axis("blockIdx.x")
    builder.scope_attr(thread_x, "thread_extent", rows)
    builder.scope_attr(block_x, "thread_extent", 1)

    # Flat offset of the first element of the row handled by this thread.
    row_base = thread_x * cols
    fill_value = tvm.tir.const(-1, dtype=dst.dtype)

    count[thread_x] = 0

    # Pass 1: mark the whole output row as empty.
    with builder.for_range(0, cols) as col:
        dst[row_base + col] = fill_value

    # Pass 2: copy each non-negative entry to the next free slot; the running
    # count doubles as the write cursor.
    with builder.for_range(0, cols) as col:
        with builder.if_scope(src[row_base + col] >= 0):
            dst[row_base + count[thread_x]] = src[row_base + col]
            count[thread_x] = count[thread_x] + 1

    return builder.get()


def rearrange_indices_out(data):
    """Move all valid entries of the NMS index output to the front of each row.

    Parameters
    ----------
    data : tvm.te.Tensor or numpy NDArray
        NMS output. 2-D tensor with shape [batch_size, num_anchors].

    Returns
    -------
    output : tvm.te.Tensor or numpy NDArray
        2-D tensor with shape [batch_size, num_anchors].

    valid_box_count : tvm.te.Tensor or numpy NDArray
        Tensor with shape [batch_size, 1], indicates
        the valid number of boxes.
    """
    shape = data.shape
    dtype = data.dtype

    # Explicit buffers for the extern op: one input and two outputs.
    in_buf = tvm.tir.decl_buffer(shape, dtype, "data_buf", data_alignment=8)
    indices_buf = tvm.tir.decl_buffer(shape, dtype, "out_indices_buf", data_alignment=8)
    count_buf = tvm.tir.decl_buffer((shape[0], 1), "int32", "valid_count_buf", data_alignment=8)

    def _emit_ir(ins, outs):
        # ins[0] is the data buffer; outs are (rearranged indices, valid counts).
        return rearrange_indices_out_ir(ins[0], outs[0], outs[1])

    output, valid_box_count = te.extern(
        [indices_buf.shape, count_buf.shape],
        [data],
        _emit_ir,
        in_buffers=[in_buf],
        out_buffers=[indices_buf, count_buf],
        name="rearrange_indices_out",
        tag="rearrange_indices_out_gpu",
    )
    return [output, valid_box_count]


def nms_ir(
data,
sorted_index,
Expand Down Expand Up @@ -522,6 +604,6 @@ def non_max_suppression(
)
# TODO(yongwww): Update cuda nms to be consistent with cpu version
if return_indices:
return box_indices
return rearrange_indices_out(box_indices)

return out
79 changes: 79 additions & 0 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,84 @@ def verify_nms(
)


@tvm.testing.uses_gpu
def test_non_max_suppression_gpu():
    """Check the indices returned by non_max_suppression on the CUDA target.

    Builds a relay function calling ``relay.vision.non_max_suppression`` with
    ``return_indices=True`` and compares the first output against a reference
    index array for both the graph and the debug executor.
    """

    def verify_nms(
        x0_data,
        x1_data,
        x2_data,
        x3_data,
        dshape,
        ref_res,
        ref_indices_res,
        iou_threshold=0.5,
        force_suppress=True,
        top_k=-1,
        check_type_only=False,
    ):
        """Run NMS with return_indices=True and compare output 0 to ``ref_indices_res``.

        ``ref_res`` and ``check_type_only`` are accepted but not used by this
        helper; only the index output is verified.
        """
        x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
        x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32"))
        x2 = relay.var("x2", relay.ty.TensorType((dshape[0], dshape[1]), "int32"))
        x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
        z_indices = relay.vision.non_max_suppression(
            x0,
            x1,
            x2,
            x3,
            iou_threshold=iou_threshold,
            force_suppress=force_suppress,
            top_k=top_k,
            return_indices=True,
        )
        if isinstance(z_indices, relay.expr.TupleWrapper):
            z_indices = z_indices.astuple()
        # Run type inference on the bare expression; the result itself is not needed.
        run_infer_type(z_indices)

        func_indices = relay.Function([x0, x1, x2, x3], z_indices)
        func_indices = run_infer_type(func_indices)
        for target, ctx in tvm.testing.enabled_targets():
            # This test only covers the CUDA implementation.
            if target != "cuda":
                continue
            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
            op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
            tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5)
            # The second executor must be created before use; previously
            # ``intrp2`` was referenced without ever being defined.
            intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
            op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
            tvm.testing.assert_allclose(op_indices_res2[0].asnumpy(), ref_indices_res, rtol=1e-5)

    # data after get_valid_counts
    np_data = np.array(
        [
            [
                [0, 0.8, 1, 20, 25, 45],
                [1, 0.7, 2, 21, 26, 45],
                [-1, -1, -1, -1, -1, -1],
                [2, 0.9, 35, 61, 52, 79],
                [1, 0.5, 100, 60, 70, 110],
            ]
        ]
    ).astype("float32")
    np_indices = np.array([[0, 1, -1, 3, 4]]).astype("int32")
    np_valid_count = np.array([4]).astype("int32")
    np_max_output_size = -1
    num_anchors = 5
    dshape = (1, num_anchors, 6)
    np_result = np.array(
        [
            [
                [2, 0.9, 35, 61, 52, 79],
                [0, 0.8, 1, 20, 25, 45],
                [-1, -1, -1, -1, -1, -1],
                [-1, -1, -1, -1, -1, -1],
                [-1, -1, -1, -1, -1, -1],
            ]
        ]
    )
    # Expected indices: the kept boxes first, the remaining slots filled with -1.
    np_indices_result = np.array([[3, 0, -1, -1, -1]])
    verify_nms(
        np_data,
        np_valid_count,
        np_indices,
        np_max_output_size,
        dshape,
        np_result,
        np_indices_result,
        top_k=2,
    )

@tvm.testing.uses_gpu
def test_multibox_transform_loc():
def test_default_value():
Expand Down Expand Up @@ -1156,6 +1234,7 @@ def verify_grid_sample(data_shape, grid_shape):
test_yolo_reorg_infer_shape()
test_yolo_reorg()
test_non_max_suppression()
test_non_max_suppression_gpu()
test_deformable_conv2d()
test_depth_to_space()
test_space_to_depth()
Expand Down
8 changes: 2 additions & 6 deletions tests/python/topi/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,8 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)

tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
if device == "llvm":
f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
else:
f = tvm.build(indices_s, [data, valid_count, indices, indices_out], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)

for device in ["llvm", "cuda", "opencl"]:
Expand Down