4 changes: 4 additions & 0 deletions python/tvm/tir/ir_builder.py
@@ -21,6 +21,7 @@

from . import stmt as _stmt
from . import expr as _expr
from . import op


class WithScope(object):
@@ -200,6 +201,9 @@ def scope_attr(self, node, attr_key, value):
node = _expr.StringImm(node)
if isinstance(value, string_types):
value = _expr.StringImm(value)
# thread_extent could be zero for dynamic workloads
if attr_key == "thread_extent":
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
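Context for the ir_builder change above: with dynamic shapes, a launch extent such as ceil_div(num_anchors, max_threads) can evaluate to zero at runtime, and a zero thread_extent yields an invalid kernel launch. A minimal sketch of the clamping behavior (variable names here are illustrative, not part of the patch):

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    n = te.var("n")  # symbolic extent; may be 0 for dynamic workloads
    bx = te.thread_axis("blockIdx.x")
    # With this patch, scope_attr rewrites the extent to max(1, n),
    # so the emitted AttrStmt can never request a zero-sized launch.
    ib.scope_attr(bx, "thread_extent", n)
    stmt = ib.get()  # AttrStmt whose value is max(1, n)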
279 changes: 98 additions & 181 deletions python/tvm/topi/cuda/nms.py
@@ -51,68 +51,8 @@ def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)


def rearrange_indices_out_ir(data, output, valid_box_count):
"""Hybrid routine to rearrange nms output to
move all valid entries to top.

Parameters
----------
data : tvm.te.Tensor or numpy NDArray
NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6] or
[batch_size, num_anchors, 5], or 2-D
tensor with shape [batch_size, num_anchors].

one: tvm.tir.const
Constant one with the same dtype as data.

batch_size: tvm.tir.IntImm or tvm.tir.Var
Batch size. We need to pass it in since hybrid script doesn't support
binding variable to symbolic dim.

num_anchors: tvm.tir.IntImm or tvm.tir.Var
Number of anchors.

Returns
-------
output : tvm.te.Tensor or numpy NDArray
2-D tensor with shape [batch_size, num_anchors].

valid_box_count : tvm.te.Tensor or numpy NDArray
Tensor with shape [batch_size, 1], indicates
the valid number of boxes.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
valid_box_count = ib.buffer_ptr(valid_box_count)
output = ib.buffer_ptr(output)

with ib.new_scope():
i = te.thread_axis("blockIdx.x")
ib.scope_attr(i, "thread_extent", batch_size)
valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
valid_idx[0] = 0
with ib.for_range(0, num_anchors, name="j") as j:
with ib.if_scope(data[i, j] >= 0):
with ib.if_scope(data[i, j] > num_anchors):
output[i, valid_idx[0]] = 0
valid_idx[0] = valid_idx[0] + 1
with ib.else_scope():
output[i, valid_idx[0]] = data[i, j]
valid_idx[0] = valid_idx[0] + 1
with ib.else_scope():
with ib.if_scope(data[i, j] < -num_anchors):
output[i, valid_idx[0]] = 0
valid_idx[0] = valid_idx[0] + 1
with ib.if_scope(j >= valid_idx[0]):
output[i, j] = -1
valid_box_count[i, 0] = valid_idx[0]

return ib.get()
def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)
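
The new ceil_div helper computes a rounded-up quotient so that a grid of fixed-size thread blocks also covers a partial final block. A plain-Python restatement of the arithmetic (for illustration only):

    def ceil_div_py(a, b):
        # reference for tvm.tir.indexdiv(a + b - 1, b), which truncates
        return (a + b - 1) // b

    assert ceil_div_py(10, 3) == 4  # 10 anchors, 3 threads/block -> 4 blocks
    assert ceil_div_py(9, 3) == 3   # exact division needs no extra block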


def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
@@ -400,6 +340,7 @@ def nms_ir(
indices,
out,
box_indices,
num_valid_boxes,
max_output_size,
iou_threshold,
force_suppress,
@@ -430,7 +371,15 @@
is not used before non_max_suppression.

out : Buffer
Output buffer.
Output buffer, to be filled with sorted boxes.

box_indices : Buffer
An indices tensor mapping sorted indices to original indices.
This is the first output of NMS when return_indices=True.

num_valid_boxes : Buffer
Records the number of boxes that have survived IOU tests.
This is the second output of NMS when return_indices=True.

max_output_size : int
Max number of output valid boxes for each instance.
@@ -509,6 +458,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
sorted_index = ib.buffer_ptr(sorted_index)
valid_count = ib.buffer_ptr(valid_count)
indices = ib.buffer_ptr(indices)
num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
out = ib.buffer_ptr(out)
box_indices = ib.buffer_ptr(box_indices)

@@ -523,132 +473,111 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = by
base_idx = i * num_anchors * box_data_length
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Reorder output
nkeep = if_then_else(
tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
)
with ib.for_range(0, nkeep) as j:
j = bx * max_threads + tx
with ib.if_scope(j < num_anchors):
box_indices[i * num_anchors + j] = -1
with ib.if_scope(j < nkeep):
# Fill in out with sorted boxes
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = data[
(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
]
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
with ib.for_range(0, valid_count[i] - nkeep) as j:
with ib.else_scope():
# Indices > nkeep are discarded
with ib.if_scope(j < num_anchors):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
box_indices[i * num_anchors + (j + nkeep)] = -1
out[(base_idx + j * box_data_length + k)] = -1.0
with ib.else_scope():
with ib.if_scope(j < valid_count[i]):
with ib.for_range(0, box_data_length) as k:
offset = base_idx + j * box_data_length + k
out[offset] = data[offset]
box_indices[i * num_anchors + j] = j

with ib.new_scope():
nthread_by = batch_size
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
i = by
base_idx = i * num_anchors * box_data_length
num_valid_boxes_local = ib.allocate(
"int32", (1,), name="num_valid_boxes_local", scope="local"
)
num_valid_boxes_local[0] = 0

def nms_inner_loop(ib, j):
offset_j = j * box_data_length

with ib.for_range(0, j) as k:
offset_k = k * box_data_length

with ib.if_scope(
tvm.tir.all(
out[base_idx + offset_j + score_index] > -1.0, # box j not already suppressed
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index]
== out[base_idx + offset_j + id_index],
),
)
):
iou = calculate_overlap(
out,
base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_j + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0

# Has box j survived the IOU tests?
with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
# When return_indices is False, no need to populate box_indices
if return_indices:
orig_idx = sorted_index[i * num_anchors + j]
box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
num_valid_boxes_local[0] += 1

if isinstance(max_output_size, int):
max_output_size = tvm.tir.const(max_output_size)

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.for_range(0, j) as k:
offset_k = k * box_data_length
with ib.if_scope(
tvm.tir.all(
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
)
):
offset_j = j * box_data_length
with ib.if_scope(
tvm.tir.all(
j > k,
out[base_idx + offset_k + score_index] > 0,
tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0),
tvm.tir.any(
force_suppress > 0,
id_index < 0,
out[base_idx + offset_k + id_index]
== out[base_idx + offset_j + id_index],
),
)
):
iou = calculate_overlap(
out,
base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start,
)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_j + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_by = batch_size
nthread_bz = box_data_length
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
tid = bx * max_threads + tx
i = by
j = tid
k = bz
base_idx = i * num_anchors * box_data_length
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
pass
with ib.else_scope():
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
box_indices[i * num_anchors + j] = j

with ib.new_scope():
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", batch_size)
i = bx
base_idx = i * num_anchors * box_data_length
# Set invalid entry to be -1
with ib.for_range(0, num_anchors - valid_count[i]) as j:
with ib.for_range(0, box_data_length) as k:
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
box_indices[i * num_anchors + j + valid_count[i]] = -1
# Only return max_output_size number of valid boxes
num_valid_boxes[0] = 0
with ib.if_scope(max_output_size > 0):
with ib.for_range(0, valid_count[i]) as j:
offset_j = j * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0):
with ib.if_scope(num_valid_boxes[0] == max_output_size):
with ib.for_range(0, box_data_length) as k:
out[base_idx + offset_j + k] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.if_scope(
tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
):
with ib.if_scope(max_output_size > 0):
# No need to iterate further once we already have max_output_size boxes
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
num_valid_boxes[0] += 1
nms_inner_loop(ib, j)

if return_indices:
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = batch_size // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = bx * max_threads + tx
with ib.if_scope(i < batch_size):
with ib.for_range(0, valid_count[i]) as j:
idx = box_indices[i * num_anchors + j]
with ib.if_scope(idx >= 0):
box_indices[i * num_anchors + j] = indices[i * num_anchors + idx]
num_valid_boxes[i] = num_valid_boxes_local[0]

with ib.else_scope():
num_valid_boxes[i] = 0

return ib.get()
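
For readers following the new structure: nms_inner_loop implements the classic serial suppression scan, where box j survives only if no earlier surviving box overlaps it by iou_threshold or more, and scanning stops once max_output_size boxes have been kept. A simplified CPU sketch of that semantics, assuming boxes are pre-sorted by descending score and ignoring the id_index/force_suppress class handling (illustrative only, not the kernel code):

    def iou(a, b):
        # intersection-over-union of two [x1, y1, x2, y2] boxes
        w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
        h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
        inter = w * h
        area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
        return inter / (area(a) + area(b) - inter)

    def nms_sketch(boxes, iou_threshold, max_output_size=-1):
        keep = []
        for j in range(len(boxes)):
            if 0 <= max_output_size <= len(keep):
                break  # already kept max_output_size boxes
            if all(iou(boxes[j], boxes[k]) < iou_threshold for k in keep):
                keep.append(j)
        return keep, len(keep)  # analogous to a box_indices row and num_valid_boxes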

@@ -816,13 +745,11 @@ def non_max_suppression(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
)

indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)

out, box_indices = te.extern(
[data.shape, score_shape],
out, box_indices, num_valid_boxes = te.extern(
[data.shape, score_shape, [batch_size, 1]],
[data, sort_tensor, valid_count, indices],
lambda ins, outs: nms_ir(
ins[0],
@@ -831,6 +758,7 @@
ins[3],
outs[0],
outs[1],
outs[2],
max_output_size,
iou_threshold,
force_suppress,
@@ -840,24 +768,13 @@
score_index,
return_indices,
),
dtype=[data.dtype, "int32"],
dtype=[data.dtype, "int32", "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
name="nms",
tag="nms",
)

if return_indices:
out_shape = box_indices.shape
valid_box_count_shape = [box_indices.shape[0], 1]
valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count")
output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
return te.extern(
[out_shape, valid_box_count_shape],
[box_indices],
lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
dtype="int32",
out_buffers=[output, valid_box_count],
name="rearrange_indices_out_gpu",
tag="rearrange_indices_out_gpu",
)
return [box_indices, num_valid_boxes]

return out
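
For reference, a sketch of how the reworked kernel is driven from TOPI (shapes and parameter values are illustrative; the signature is the standard topi non_max_suppression one, and with return_indices=True this PR makes the CUDA path return the selected indices plus the per-batch valid count directly, with no extra rearrange_indices_out_gpu pass):

    import tvm
    from tvm import te, topi

    batch, num_anchors = 1, 100
    # data layout per box: [id, score, x1, y1, x2, y2]
    data = te.placeholder((batch, num_anchors, 6), name="data")
    valid_count = te.placeholder((batch,), dtype="int32", name="valid_count")
    indices = te.placeholder((batch, num_anchors), dtype="int32", name="indices")

    # A CUDA target must be current, since nms_ir queries max_num_threads.
    with tvm.target.Target("cuda"):
        box_indices, num_valid_boxes = topi.cuda.non_max_suppression(
            data, valid_count, indices,
            max_output_size=-1, iou_threshold=0.5, force_suppress=False,
            top_k=-1, coord_start=2, score_index=1, id_index=0,
            return_indices=True, invalid_to_bottom=False,
        )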