diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 75c5c2921ff4..6dcc8580a221 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -21,6 +21,7 @@
 from . import stmt as _stmt
 from . import expr as _expr
+from . import op
 
 
 class WithScope(object):
@@ -200,6 +201,9 @@ def scope_attr(self, node, attr_key, value):
             node = _expr.StringImm(node)
         if isinstance(value, string_types):
             value = _expr.StringImm(value)
+        # thread_extent could be zero for dynamic workloads
+        if attr_key == "thread_extent":
+            value = op.max(1, value)
         self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
 
     def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 273397071219..cea287edd62e 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -51,68 +51,8 @@ def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
-def rearrange_indices_out_ir(data, output, valid_box_count):
-    """Hybrid routine to rearrange nms output to
-    move all valid entries to top.
-
-    Parameters
-    ----------
-    data : tvm.te.Tensor or numpy NDArray
-        NMS output. 3-D tensor with shape
-        [batch_size, num_anchors, 6] or
-        [batch_size, num_anchors, 5], or 2-D
-        tensor with shape [batch_size, num_anchors].
-
-    one: tvm.tir.const
-        Constant one with the same dtype as data.
-
-    batch_size: tvm.tir.IntImm or tvm.tir.Var
-        Batch size. We need to pass it in since hybrid script doesn't support
-        binding variable to symbolic dim.
-
-    num_anchors: tvm.tir.IntImm or tvm.tir.Var
-        Number of anchors.
-
-    Returns
-    -------
-    output : tvm.te.Tensor or numpy NDArray
-        2-D tensor with shape [batch_size, num_anchors].
-
-    valid_box_count : tvm.te.Tensor or numpy NDArray
-        Tensor with shape [batch_size, 1], indicates
-        the valid number of boxes.
-    """
-    batch_size = data.shape[0]
-    num_anchors = data.shape[1]
-
-    ib = tvm.tir.ir_builder.create()
-
-    data = ib.buffer_ptr(data)
-    valid_box_count = ib.buffer_ptr(valid_box_count)
-    output = ib.buffer_ptr(output)
-
-    with ib.new_scope():
-        i = te.thread_axis("blockIdx.x")
-        ib.scope_attr(i, "thread_extent", batch_size)
-        valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
-        valid_idx[0] = 0
-        with ib.for_range(0, num_anchors, name="j") as j:
-            with ib.if_scope(data[i, j] >= 0):
-                with ib.if_scope(data[i, j] > num_anchors):
-                    output[i, valid_idx[0]] = 0
-                    valid_idx[0] = valid_idx[0] + 1
-                with ib.else_scope():
-                    output[i, valid_idx[0]] = data[i, j]
-                    valid_idx[0] = valid_idx[0] + 1
-            with ib.else_scope():
-                with ib.if_scope(data[i, j] < -num_anchors):
-                    output[i, valid_idx[0]] = 0
-                    valid_idx[0] = valid_idx[0] + 1
-            with ib.if_scope(j >= valid_idx[0]):
-                output[i, j] = -1
-        valid_box_count[i, 0] = valid_idx[0]
-
-    return ib.get()
+def ceil_div(a, b):
+    return tvm.tir.indexdiv(a + b - 1, b)
 
 
 def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
@@ -400,6 +340,7 @@
     indices,
     out,
     box_indices,
+    num_valid_boxes,
     max_output_size,
     iou_threshold,
     force_suppress,
@@ -430,7 +371,15 @@
         is not used before non_max_suppression.
 
     out : Buffer
-        Output buffer.
+        Output buffer, to be filled with sorted boxes.
+
+    box_indices : Buffer
+        An indices tensor mapping sorted indices to original indices.
+        This is the first output of NMS when return_indices=True.
+
+    num_valid_boxes : Buffer
+        Records the number of boxes that have survived IOU tests.
+        This is the second output of NMS when return_indices=True.
 
     max_output_size : int
         Max number of output valid boxes for each instance.
@@ -509,6 +458,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
     sorted_index = ib.buffer_ptr(sorted_index)
     valid_count = ib.buffer_ptr(valid_count)
     indices = ib.buffer_ptr(indices)
+    num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
     out = ib.buffer_ptr(out)
     box_indices = ib.buffer_ptr(box_indices)
 
@@ -523,9 +473,15 @@
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
 
     with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(num_anchors, max_threads)
         nthread_by = batch_size
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
         ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
         i = by
         base_idx = i * num_anchors * box_data_length
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
@@ -533,122 +489,95 @@
             nkeep = if_then_else(
                 tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
             )
-            with ib.for_range(0, nkeep) as j:
+            j = bx * max_threads + tx
+            with ib.if_scope(j < num_anchors):
+                box_indices[i * num_anchors + j] = -1
+            with ib.if_scope(j < nkeep):
+                # Fill in out with sorted boxes
                 with ib.for_range(0, box_data_length) as k:
                     out[(base_idx + j * box_data_length + k)] = data[
                         (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
                     ]
-                box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
-            with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
-                with ib.for_range(0, valid_count[i] - nkeep) as j:
+            with ib.else_scope():
+                # Indices >= nkeep are discarded
+                with ib.if_scope(j < num_anchors):
                     with ib.for_range(0, box_data_length) as k:
-                        out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
-                    box_indices[i * num_anchors + (j + nkeep)] = -1
+                        out[(base_idx + j * box_data_length + k)] = -1.0
+        with ib.else_scope():
+            with ib.if_scope(j < valid_count[i]):
+                with ib.for_range(0, box_data_length) as k:
+                    offset = base_idx + j * box_data_length + k
+                    out[offset] = data[offset]
+                box_indices[i * num_anchors + j] = j
+
     with ib.new_scope():
         nthread_by = batch_size
         by = te.thread_axis("blockIdx.y")
         ib.scope_attr(by, "thread_extent", nthread_by)
         i = by
         base_idx = i * num_anchors * box_data_length
+        num_valid_boxes_local = ib.allocate(
+            "int32", (1,), name="num_valid_boxes_local", scope="local"
+        )
+        num_valid_boxes_local[0] = 0
+
+        def nms_inner_loop(ib, j):
+            offset_j = j * box_data_length
+
+            with ib.for_range(0, j) as k:
+                offset_k = k * box_data_length
+
+                with ib.if_scope(
+                    tvm.tir.all(
+                        out[base_idx + offset_j + score_index] > -1.0,  # j is not yet suppressed
+                        out[base_idx + offset_k + score_index] > 0,
+                        tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
+                        tvm.tir.any(
+                            force_suppress > 0,
+                            id_index < 0,
+                            out[base_idx + offset_k + id_index]
+                            == out[base_idx + offset_j + id_index],
+                        ),
+                    )
+                ):
+                    iou = calculate_overlap(
+                        out,
+                        base_idx + offset_j + coord_start,
+                        base_idx + offset_k + coord_start,
+                    )
+                    with ib.if_scope(iou >= iou_threshold):
+                        out[base_idx + offset_j + score_index] = -1.0
+                        with ib.if_scope(id_index >= 0):
+                            out[base_idx + offset_j + id_index] = -1.0
+
+            # Has box j survived the IOU tests?
+            with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
+                # When return_indices is False, no need to populate box_indices
                if return_indices:
+                    orig_idx = sorted_index[i * num_anchors + j]
+                    box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
+                num_valid_boxes_local[0] += 1
+
+        if isinstance(max_output_size, int):
+            max_output_size = tvm.tir.const(max_output_size)
+
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Apply nms
             with ib.for_range(0, valid_count[i]) as j:
-                with ib.for_range(0, j) as k:
-                    offset_k = k * box_data_length
-                    with ib.if_scope(
-                        tvm.tir.all(
-                            out[base_idx + offset_k + score_index] > 0,
-                            tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
-                        )
-                    ):
-                        offset_j = j * box_data_length
-                        with ib.if_scope(
-                            tvm.tir.all(
-                                j > k,
-                                out[base_idx + offset_k + score_index] > 0,
-                                tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0),
-                                tvm.tir.any(
-                                    force_suppress > 0,
-                                    id_index < 0,
-                                    out[base_idx + offset_k + id_index]
-                                    == out[base_idx + offset_j + id_index],
-                                ),
-                            )
-                        ):
-                            iou = calculate_overlap(
-                                out,
-                                base_idx + offset_j + coord_start,
-                                base_idx + offset_k + coord_start,
-                            )
-                            with ib.if_scope(iou >= iou_threshold):
-                                out[base_idx + offset_j + score_index] = -1.0
-                                with ib.if_scope(id_index >= 0):
-                                    out[base_idx + offset_j + id_index] = -1.0
-                                box_indices[i * num_anchors + j] = -1
-
-    with ib.new_scope():
-        nthread_tx = max_threads
-        nthread_bx = num_anchors // max_threads + 1
-        nthread_by = batch_size
-        nthread_bz = box_data_length
-        tx = te.thread_axis("threadIdx.x")
-        bx = te.thread_axis("blockIdx.x")
-        by = te.thread_axis("blockIdx.y")
-        bz = te.thread_axis("blockIdx.z")
-        ib.scope_attr(tx, "thread_extent", nthread_tx)
-        ib.scope_attr(bx, "thread_extent", nthread_bx)
-        ib.scope_attr(by, "thread_extent", nthread_by)
-        ib.scope_attr(bz, "thread_extent", nthread_bz)
-        tid = bx * max_threads + tx
-        i = by
-        j = tid
-        k = bz
-        base_idx = i * num_anchors * box_data_length
-        with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
-            pass
-        with ib.else_scope():
-            with ib.if_scope(j < valid_count[i]):
-                offset_j = j * box_data_length
-                out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
-                box_indices[i * num_anchors + j] = j
-
-    with ib.new_scope():
-        num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", batch_size)
-        i = bx
-        base_idx = i * num_anchors * box_data_length
-        # Set invalid entry to be -1
-        with ib.for_range(0, num_anchors - valid_count[i]) as j:
-            with ib.for_range(0, box_data_length) as k:
-                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
-            box_indices[i * num_anchors + j + valid_count[i]] = -1
-        # Only return max_output_size number of valid boxes
-        num_valid_boxes[0] = 0
-        with ib.if_scope(max_output_size > 0):
-            with ib.for_range(0, valid_count[i]) as j:
-                offset_j = j * box_data_length
-                with ib.if_scope(out[base_idx + offset_j] >= 0):
-                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
-                        with ib.for_range(0, box_data_length) as k:
-                            out[base_idx + offset_j + k] = -1.0
-                        box_indices[i * num_anchors + j] = -1
+                with ib.if_scope(
+                    tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
+                ):
+                    with ib.if_scope(max_output_size > 0):
+                        # No need to iterate further once we have max_output_size boxes
+                        with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
+                            nms_inner_loop(ib, j)
                     with ib.else_scope():
-                        num_valid_boxes[0] += 1
+                        nms_inner_loop(ib, j)
 
-
-    if return_indices:
-        with ib.new_scope():
-            nthread_tx = max_threads
-            nthread_bx = batch_size // max_threads + 1
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            ib.scope_attr(bx, "thread_extent", nthread_bx)
-            i = bx * max_threads + tx
-            with ib.if_scope(i < batch_size):
-                with ib.for_range(0, valid_count[i]) as j:
-                    idx = box_indices[i * num_anchors + j]
-                    with ib.if_scope(idx >= 0):
-                        box_indices[i * num_anchors + j] = indices[i * num_anchors + idx]
+            num_valid_boxes[i] = num_valid_boxes_local[0]
+
+        with ib.else_scope():
+            num_valid_boxes[i] = 0
 
     return ib.get()
@@ -816,13 +745,11 @@
         sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
     )
-    indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
-
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
 
     indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
 
-    out, box_indices = te.extern(
-        [data.shape, score_shape],
+    out, box_indices, num_valid_boxes = te.extern(
+        [data.shape, score_shape, [batch_size, 1]],
         [data, sort_tensor, valid_count, indices],
         lambda ins, outs: nms_ir(
             ins[0],
@@ -831,6 +758,7 @@
             ins[3],
             outs[0],
            outs[1],
+            outs[2],
             max_output_size,
             iou_threshold,
             force_suppress,
@@ -840,24 +768,13 @@
             score_index,
             return_indices,
         ),
-        dtype=[data.dtype, "int32"],
+        dtype=[data.dtype, "int32", "int32"],
        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
         name="nms",
         tag="nms",
     )
+
     if return_indices:
-        out_shape = box_indices.shape
-        valid_box_count_shape = [box_indices.shape[0], 1]
-        valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count")
-        output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
-        return te.extern(
-            [out_shape, valid_box_count_shape],
-            [box_indices],
-            lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
-            dtype="int32",
-            out_buffers=[output, valid_box_count],
-            name="rearrange_indices_out_gpu",
-            tag="rearrange_indices_out_gpu",
-        )
+        return [box_indices, num_valid_boxes]
 
     return out
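
Note (not part of the patch): a quick illustration of the grid-sizing arithmetic the new kernel relies on. `ceil_div` is ordinary ceiling division, and the new `thread_extent` clamp in `ir_builder.scope_attr` guarantees a launch extent of at least 1 when a dynamic workload turns out to be empty. The plain-Python sketch below mirrors that behaviour with made-up `num_anchors`/`max_threads` values; it is illustrative only and does not call the TVM API.

    def ceil_div(a, b):
        # Same arithmetic as tvm.tir.indexdiv(a + b - 1, b): how many
        # size-b thread blocks are needed to cover a elements.
        return (a + b - 1) // b

    max_threads = 256  # hypothetical CUDA threads per block
    for num_anchors in (0, 1, 1000):  # a dynamic workload may be empty
        nthread_bx = ceil_div(num_anchors, max_threads)
        # scope_attr now rewrites thread_extent to op.max(1, value), so even a
        # zero-sized extent still launches a single block / thread.
        launch_extent = max(1, nthread_bx)
        print(num_anchors, nthread_bx, launch_extent)  # (0, 0, 1), (1, 1, 1), (1000, 4, 4)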