93 changes: 89 additions & 4 deletions python/tvm/topi/cuda/nms.py
@@ -54,6 +54,68 @@ def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)


def rearrange_indices_out_ir(data, out, valid_box_count):
"""Hybrid routine to rearrange nms output to
move all valid entries to top.

Parameters
----------
data : tvm.te.Tensor or numpy NDArray
tensor with shape [batch_size, num_anchors].


Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]

ib = tvm.tir.ir_builder.create()
data = ib.buffer_ptr(data)
out = ib.buffer_ptr(out)
valid_box_count = ib.buffer_ptr(valid_box_count)

one_count = tvm.tir.const(1, dtype="int32")
atomic_add_return = ib.allocate(
valid_box_count.dtype, (batch_size,), name="atomic_add_return", scope="local"
)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
len_inner_for = (batch_size * num_anchors) // nthread_tx + 2

idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod

with ib.for_range(0, len_inner_for, name="i") as i:
idx = tx * len_inner_for + i
batch_idx = idxd(idx, num_anchors)
with ib.if_scope(idx < batch_size):
valid_box_count[idx] = 0
with ib.if_scope(idx < batch_size * num_anchors):
with ib.if_scope(data[idx] >= 0):
atomic_add_return[batch_idx] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]),
one_count,
)
out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = data[idx]
with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] < -num_anchors)):
atomic_add_return[batch_idx] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]),
one_count,
)
out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0

with ib.if_scope(idxm(idx, num_anchors) >= valid_box_count[batch_idx]):
out[idx] = -1

return ib.get()
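
For reference, a minimal sequential NumPy sketch of what `rearrange_indices_out_ir` computes per batch row (an illustration, not part of this diff). The GPU kernel reserves output slots with `atomic_add`, so the relative order of valid entries within a row is not guaranteed to match this sequential version.

```python
import numpy as np

def rearrange_indices_out_ref(data):
    # Sequential reference (illustration only): compact valid NMS indices to the
    # front of each row, pad the remainder with -1, and count valid entries.
    batch_size, num_anchors = data.shape
    out = np.full((batch_size, num_anchors), -1, dtype=data.dtype)
    valid_box_count = np.zeros((batch_size, 1), dtype="int32")
    for i in range(batch_size):
        cnt = 0
        for j in range(num_anchors):
            if data[i, j] >= 0:
                out[i, cnt] = data[i, j]
                cnt += 1
            # out-of-range entries are written out as 0, mirroring the kernel
            if data[i, j] > num_anchors or data[i, j] < -num_anchors:
                out[i, cnt] = 0
                cnt += 1
        valid_box_count[i, 0] = cnt
    return out, valid_box_count

# Example: [[3, -1, 0, -1, 2]] -> ([[3, 0, 2, -1, -1]], [[3]])
print(rearrange_indices_out_ref(np.array([[3, -1, 0, -1, 2]], dtype="int32")))
```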


def get_valid_counts_ir(
data, valid_count, out, out_indices, score_threshold, id_index, score_index
):
@@ -198,6 +260,7 @@ def nms_ir(
data,
sorted_index,
valid_count,
indices,
out,
box_indices,
max_output_size,
@@ -207,6 +270,7 @@
coord_start,
id_index,
score_index,
return_indices,
):
"""Low level IR routing for transform location in multibox_detection operator.

@@ -285,6 +349,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
valid_count = ib.buffer_ptr(valid_count)
out = ib.buffer_ptr(out)
box_indices = ib.buffer_ptr(box_indices)
indices = ib.buffer_ptr(indices)
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
@@ -379,6 +444,12 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
with ib.else_scope():
num_valid_boxes[0] += 1

if return_indices:
with ib.if_scope(j < valid_count[i]):
box_idx = box_indices[i * num_anchors + j]
with ib.if_scope(box_idx >= 0):
box_indices[i * num_anchors + j] = indices[i * num_anchors + box_idx]

return ib.get()
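
The `return_indices` branch added above appears to compose `box_indices` with the provided `indices` tensor: each non-negative entry `box_idx` is replaced by `indices[i, box_idx]`, so the reported indices refer to positions in the original input data rather than the NMS-internal ordering. A rough sequential sketch of that remapping (illustrative only; the helper name is made up):

```python
import numpy as np

def remap_box_indices(box_indices, indices, valid_count):
    # Illustration only: translate kept box indices back to positions in the
    # original input, as recorded per batch in `indices`.
    out = box_indices.copy()
    for i in range(out.shape[0]):
        for j in range(int(valid_count[i])):
            if out[i, j] >= 0:
                out[i, j] = indices[i, out[i, j]]
    return out

# Example: kept positions [1, 0] remapped through [2, 5, 7, 9, -1]
# -> [[5, 2, -1, -1, -1]]
print(remap_box_indices(np.array([[1, 0, -1, -1, -1]]),
                        np.array([[2, 5, 7, 9, -1]]),
                        np.array([2])))
```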


@@ -502,14 +573,16 @@ def non_max_suppression(
)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)

out, box_indices = te.extern(
[data.shape, score_shape],
[data, sort_tensor, valid_count],
[data, sort_tensor, valid_count, indices],
lambda ins, outs: nms_ir(
ins[0],
ins[1],
ins[2],
ins[3],
outs[0],
outs[1],
max_output_size,
Expand All @@ -519,14 +592,26 @@ def non_max_suppression(
coord_start,
id_index,
score_index,
return_indices,
),
dtype=[data.dtype, "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
name="nms",
tag="nms",
)
# TODO(yongwww): Update cuda nms to be consistent with cpu version

if return_indices:
return box_indices
out_buf = tvm.tir.decl_buffer(
box_indices.shape, box_indices.dtype, "out_buf", data_alignment=8
)
return te.extern(
[box_indices.shape, (batch_size, 1)],
[box_indices],
lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
dtype=[box_indices.dtype, valid_count.dtype],
in_buffers=[out_buf],
name="rearrange_indices_out",
tag="rearrange_indices_out",
)

return out
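
With `return_indices=True`, the CUDA path now returns the same pair as the CPU implementation: the selected indices compacted to the front of each row (padded with -1) and a per-batch valid box count, produced by the extra `rearrange_indices_out_ir` kernel. A small consumption sketch (illustration only; the sample values mirror the relay test added further down):

```python
import numpy as np

# Outputs as they would come back from the op with return_indices=True
# (values copied from the test case below).
selected = np.array([[4, 0, -1, -1, -1]], dtype="int32")
valid_box_count = np.array([[2]], dtype="int32")

for b in range(selected.shape[0]):
    kept = selected[b, : int(valid_box_count[b, 0])]
    print("batch", b, "keeps input boxes", kept.tolist())  # batch 0 keeps [4, 0]
```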
24 changes: 13 additions & 11 deletions python/tvm/topi/vision/nms.py
@@ -52,15 +52,16 @@ def hybrid_rearrange_box_out(data, one, batch_size, num_anchors):
"""
elem_length = data.shape[2]
output = output_tensor((batch_size, num_anchors, elem_length), data.dtype)
valid_indices = allocate((batch_size,), "int32")

for i in parallel(batch_size):
valid_idx = 0
valid_indices[i] = 0
for j in range(num_anchors):
if data[i, j, 0] >= 0:
for k in range(elem_length):
output[i, valid_idx, k] = data[i, j, k]
valid_idx += 1
if j >= valid_idx:
output[i, valid_indices[i], k] = data[i, j, k]
valid_indices[i] += 1
if j >= valid_indices[i]:
for k in range(elem_length):
output[i, j, k] = -one
return output
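
For comparison, a sequential NumPy reference of `hybrid_rearrange_box_out`'s effect (illustration only): rows whose class id (element 0) is non-negative are moved to the top of each batch, and the remaining rows are filled with `-one`. The change in this hunk only swaps the scalar counter `valid_idx` for a per-batch `valid_indices` buffer; the computed result is the same.

```python
import numpy as np

def rearrange_box_out_ref(data, one=1.0):
    # Sequential reference (illustration only): move rows with a non-negative
    # class id to the top of each batch; fill the tail rows with -one.
    batch_size, num_anchors, elem_length = data.shape
    out = np.full(data.shape, -one, dtype=data.dtype)
    for i in range(batch_size):
        cnt = 0
        for j in range(num_anchors):
            if data[i, j, 0] >= 0:
                out[i, cnt] = data[i, j]
                cnt += 1
    return out

boxes = np.array([[[-1, 0, 0, 0, 1, 1],
                   [0, 0.9, 0, 0, 1, 1]]], dtype="float32")
print(rearrange_box_out_ref(boxes))  # valid row moves to index 0, row 1 becomes all -1
```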
@@ -100,19 +101,20 @@ def hybrid_rearrange_indices_out(data, one, batch_size, num_anchors):
"""
valid_box_count = output_tensor((batch_size, 1), "int32")
output = output_tensor((batch_size, num_anchors), data.dtype)
valid_indices = allocate((batch_size,), "int32")

for i in parallel(batch_size):
valid_idx = 0
valid_indices[i] = 0
for j in range(num_anchors):
if data[i, j] >= 0:
output[i, valid_idx] = data[i, j]
valid_idx += 1
output[i, valid_indices[i]] = data[i, j]
valid_indices[i] += 1
if data[i, j] > num_anchors or data[i, j] < -num_anchors:
output[i, valid_idx] = 0
valid_idx += 1
if j >= valid_idx:
output[i, valid_indices[i]] = 0
valid_indices[i] += 1
if j >= valid_indices[i]:
output[i, j] = -one
valid_box_count[i, 0] = valid_idx
valid_box_count[i, 0] = valid_indices[i]

return output, valid_box_count

91 changes: 79 additions & 12 deletions tests/python/relay/test_any.py
@@ -25,6 +25,8 @@
from utils.assert_diagnostic import DiagnosticTesting
import tvm.topi.testing

import os


def int32(val):
return relay.const(val, "int32")
@@ -38,27 +40,43 @@ def any_dims(ndim):


def check_result(
args, mod, expected, flatten=False, assert_shape=False, only_vm=False, targets=None
args,
mod,
expected,
flatten=False,
assert_shape=False,
only_vm=False,
targets=None,
disable_targets=None,
):
if not isinstance(expected, list):
expected = [expected]
for kind in ["debug", "vm"]:
targets = targets or tvm.testing.enabled_targets()
for tgt, ctx in targets:
if disable_targets and tgt in disable_targets:
continue
if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
continue
ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt)
result = ex.evaluate()(*args)
result = result.asnumpy()
if assert_shape:
assert result.shape == expected, "Shape mismatch: expect %s but got %s." % (
str(expected),
str(result.shape),
)
return
if isinstance(result, tvm.runtime.container.ADT):
result = [r.asnumpy() for r in result]
else:
result = [result.asnumpy()]

if flatten:
result = result.flatten()
expected = expected.flatten()
tvm.testing.assert_allclose(result, expected, atol=2e-6)
for r, e in zip(result, expected):
if assert_shape:
assert r.shape == e, "Shape mismatch: expect %s but got %s." % (
str(e),
str(r),
)
return

if flatten:
r = r.flatten()
e = e.flatten()
tvm.testing.assert_allclose(r, e, atol=2e-6)


def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
@@ -1370,5 +1388,54 @@ def test_any_where():
)


# TODO(kevinthesun): enable gpu test when Thrust is available in ci.
# @tvm.testing.uses_gpu
def test_non_max_suppression():
x0 = relay.var("x0", relay.ty.TensorType((1, relay.Any(), 6), "float32"))
x1 = relay.var("x1", relay.ty.TensorType((1,), "int32"))
x2 = relay.var("x2", relay.ty.TensorType((1, relay.Any()), "int32"))
x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
z = relay.vision.non_max_suppression(
x0,
x1,
x2,
x3,
iou_threshold=0.5,
force_suppress=True,
top_k=2,
return_indices=True,
invalid_to_bottom=False,
)
z = z.astuple()
func = relay.Function([x0, x1, x2, x3], z)
mod = tvm.IRModule()
mod["main"] = func

np_data = np.array(
[
[
[0, 0.8, 1, 20, 25, 45],
[1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40],
[2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110],
]
]
).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
np_max_output_size = -1
np_indices_result = np.array([[4, 0, -1, -1, -1]])
np_valid_box_count = np.array([[2]]).astype("int32")

check_result(
[np_data, np_valid_count, np_indices, np_max_output_size],
mod,
[np_indices_result, np_valid_box_count],
only_vm=False,
disable_targets=["nvptx"],
)


if __name__ == "__main__":
pytest.main([__file__])
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_level5.py
@@ -393,8 +393,8 @@ def verify_nms(
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
if target == "cuda":
return
if target == "nvptx":
continue
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_vision.py
@@ -202,7 +202,7 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)

tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
if device == "llvm":
if device in ["llvm", "cuda"]:
f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
else: