From ff3259d00aac32a7ecd2701cf1c8406901ba7094 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 16:32:36 +0900 Subject: [PATCH 01/22] remove get_valid_counts from pytorch nms --- python/tvm/relay/frontend/pytorch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c75bd2dd3c09..2689b12c1946 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1857,16 +1857,16 @@ def nms(self, inputs, input_types): scores = inputs[1] iou_threshold = inputs[2] + num_boxes = _op.shape_of(scores) + # Generate data with shape (1, num_anchors, 5) scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {}) - - # Prepare input data for get_valid_counts data = _op.concatenate([scores, boxes], -1) data = _op.expand_dims(data, 0, 1) - # Leverage get_valid_counts to sort the data and clear invalid boxes - ct, data, indices = get_relay_op("get_valid_counts")( - data, score_threshold=-1.0, id_index=-1, score_index=0 - ) + # PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count + indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32") + indices = _op.expand_dims(indices, 0, 1) + ct = num_boxes # Perform Non-Maximum Suppression, # PyTorch NMS doesn't have parameter top_k and max_output_size From d142e5e273901887d1e19982b33a6dd692f70936 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 17:17:32 +0900 Subject: [PATCH 02/22] fix pytorch nms for negative score --- python/tvm/relay/frontend/pytorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2689b12c1946..94ee9282e4fa 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1859,6 +1859,8 @@ def nms(self, inputs, input_types): num_boxes = _op.shape_of(scores) + 
# TVM NMS assumes score > 0 + scores = scores - _op.min(scores) + _op.const(1.0) # Generate data with shape (1, num_anchors, 5) scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {}) data = _op.concatenate([scores, boxes], -1) From 8a85d57a240adb53ee86f5a0d578e61601e3a553 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 17:25:08 +0900 Subject: [PATCH 03/22] merge reset by -1 --- python/tvm/topi/cuda/nms.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 273397071219..2c9923e392a0 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -539,11 +539,10 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) ] box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] - with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])): - with ib.for_range(0, valid_count[i] - nkeep) as j: - with ib.for_range(0, box_data_length) as k: - out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 - box_indices[i * num_anchors + (j + nkeep)] = -1 + with ib.for_range(0, num_anchors - nkeep) as j: + with ib.for_range(0, box_data_length) as k: + out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + (j + nkeep)] = -1 with ib.new_scope(): nthread_by = batch_size by = te.thread_axis("blockIdx.y") @@ -617,11 +616,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ib.scope_attr(bx, "thread_extent", batch_size) i = bx base_idx = i * num_anchors * box_data_length - # Set invalid entry to be -1 - with ib.for_range(0, num_anchors - valid_count[i]) as j: - with ib.for_range(0, box_data_length) as k: - out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 - box_indices[i * num_anchors + j + valid_count[i]] = -1 # Only return max_output_size number of valid boxes 
num_valid_boxes[0] = 0 with ib.if_scope(max_output_size > 0): From 7543e063bf601edd087488079b56c7a4d69c671b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 17:49:31 +0900 Subject: [PATCH 04/22] move max_out_size handling to triangle loop --- python/tvm/topi/cuda/nms.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2c9923e392a0..5c79e85cedca 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -543,12 +543,19 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.for_range(0, box_data_length) as k: out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 box_indices[i * num_anchors + (j + nkeep)] = -1 + with ib.new_scope(): nthread_by = batch_size by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) i = by base_idx = i * num_anchors * box_data_length + num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") + num_valid_boxes[0] = 0 + + with ib.if_scope(max_output_size == 0): + max_output_size = valid_count[i] + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms with ib.for_range(0, valid_count[i]) as j: @@ -556,6 +563,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): offset_k = k * box_data_length with ib.if_scope( tvm.tir.all( + num_valid_boxes[0] < max_output_size, out[base_idx + offset_k + score_index] > 0, tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), ) @@ -584,6 +592,10 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(id_index >= 0): out[base_idx + offset_j + id_index] = -1.0 box_indices[i * num_anchors + j] = -1 + + with ib.if_scope(box_indices[i * num_anchors + j] != -1): + num_valid_boxes[0] += 1 + with ib.new_scope(): nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 @@ -610,25 +622,6 @@ def 
calculate_overlap(out_tensor, box_a_idx, box_b_idx): out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] box_indices[i * num_anchors + j] = j - with ib.new_scope(): - num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", batch_size) - i = bx - base_idx = i * num_anchors * box_data_length - # Only return max_output_size number of valid boxes - num_valid_boxes[0] = 0 - with ib.if_scope(max_output_size > 0): - with ib.for_range(0, valid_count[i]) as j: - offset_j = j * box_data_length - with ib.if_scope(out[base_idx + offset_j] >= 0): - with ib.if_scope(num_valid_boxes[0] == max_output_size): - with ib.for_range(0, box_data_length) as k: - out[base_idx + offset_j + k] = -1.0 - box_indices[i * num_anchors + j] = -1 - with ib.else_scope(): - num_valid_boxes[0] += 1 - if return_indices: with ib.new_scope(): nthread_tx = max_threads From cb5663af98bc53d7f5147e0bd4bcaa2988985f99 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 17:52:35 +0900 Subject: [PATCH 05/22] update torch nms test --- tests/python/frontend/pytorch/test_forward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2dda675c74f5..18169f31cd00 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1675,10 +1675,10 @@ def _gen_rand_inputs(num_boxes): boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5 boxes[:, 2] += boxes[:, 0] boxes[:, 3] += boxes[:, 1] - scores = torch.rand(num_boxes, dtype=torch.float) + scores = torch.from_numpy(np.random.uniform(-1, 1, size=(num_boxes,)).astype(np.float32)) return boxes, scores - targets = ["llvm"] # dynamic nms does not work on gpu + targets = ["llvm", "cuda"] for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]: in_boxes, in_scores = 
_gen_rand_inputs(num_boxes) From c419097977cf0bb3bf7e84d5db382a7e93dbe638 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 18:11:12 +0900 Subject: [PATCH 06/22] fuse the last two kernels --- python/tvm/topi/cuda/nms.py | 81 ++++++++++++------------------------- 1 file changed, 25 insertions(+), 56 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 5c79e85cedca..b11041c3cbbe 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -51,27 +51,22 @@ def atomic_add(x, y): return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y) -def rearrange_indices_out_ir(data, output, valid_box_count): - """Hybrid routine to rearrange nms output to - move all valid entries to top. +def rearrange_indices_out_ir(nms_box_indices, orig_indices, output, valid_box_count): + """Compact and remap sorted indices to original indices. Parameters ---------- - data : tvm.te.Tensor or numpy NDArray - NMS output. 3-D tensor with shape - [batch_size, num_anchors, 6] or - [batch_size, num_anchors, 5], or 2-D - tensor with shape [batch_size, num_anchors]. + nms_box_indices : tvm.te.Tensor or numpy NDArray + NMS output with return_indices=True + Tensor with shape [batch_size, num_anchors]. + Each row is indices sorted by box scores. - one: tvm.tir.const - Constant one with the same dtype as data. - - batch_size: tvm.tir.IntImm or tvm.tir.Var - Batch size. We need to pass it in since hybrid script doesn't support - binding variable to symbolic dim. - - num_anchors: tvm.tir.IntImm or tvm.tir.Var - Number of anchors. + orig_indices : Buffer + indices in original tensor, with shape [batch_size, num_anchors], + represents the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the second + dimension are like the output of arange(num_anchors) if get_valid_counts + is not used before non_max_suppression. 
Returns ------- @@ -82,12 +77,13 @@ def rearrange_indices_out_ir(data, output, valid_box_count): Tensor with shape [batch_size, 1], indicates the valid number of boxes. """ - batch_size = data.shape[0] - num_anchors = data.shape[1] + batch_size = nms_box_indices.shape[0] + num_anchors = nms_box_indices.shape[1] ib = tvm.tir.ir_builder.create() - data = ib.buffer_ptr(data) + nms_box_indices = ib.buffer_ptr(nms_box_indices) + orig_indices = ib.buffer_ptr(orig_indices) valid_box_count = ib.buffer_ptr(valid_box_count) output = ib.buffer_ptr(output) @@ -96,16 +92,17 @@ def rearrange_indices_out_ir(data, output, valid_box_count): ib.scope_attr(i, "thread_extent", batch_size) valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local") valid_idx[0] = 0 + # TODO(masahi): Use execlusive scan here with ib.for_range(0, num_anchors, name="j") as j: - with ib.if_scope(data[i, j] >= 0): - with ib.if_scope(data[i, j] > num_anchors): + with ib.if_scope(nms_box_indices[i, j] >= 0): + with ib.if_scope(nms_box_indices[i, j] > num_anchors): output[i, valid_idx[0]] = 0 valid_idx[0] = valid_idx[0] + 1 with ib.else_scope(): - output[i, valid_idx[0]] = data[i, j] + output[i, valid_idx[0]] = orig_indices[i, nms_box_indices[i, j]] valid_idx[0] = valid_idx[0] + 1 with ib.else_scope(): - with ib.if_scope(data[i, j] < -num_anchors): + with ib.if_scope(nms_box_indices[i, j] < -num_anchors): output[i, valid_idx[0]] = 0 valid_idx[0] = valid_idx[0] + 1 with ib.if_scope(j >= valid_idx[0]): @@ -397,7 +394,6 @@ def nms_ir( data, sorted_index, valid_count, - indices, out, box_indices, max_output_size, @@ -422,13 +418,6 @@ def nms_ir( valid_count : Buffer Buffer of number of valid output boxes. - indices : Buffer - indices in original tensor, with shape [batch_size, num_anchors], - represents the index of box in original data. It could be the third - output out_indices of get_valid_counts. 
The values in the second - dimension are like the output of arange(num_anchors) if get_valid_counts - is not used before non_max_suppression. - out : Buffer Output buffer. @@ -508,7 +497,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): data = ib.buffer_ptr(data) sorted_index = ib.buffer_ptr(sorted_index) valid_count = ib.buffer_ptr(valid_count) - indices = ib.buffer_ptr(indices) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) @@ -622,21 +610,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] box_indices[i * num_anchors + j] = j - if return_indices: - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = batch_size // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - i = bx * max_threads + tx - with ib.if_scope(i < batch_size): - with ib.for_range(0, valid_count[i]) as j: - idx = box_indices[i * num_anchors + j] - with ib.if_scope(idx >= 0): - box_indices[i * num_anchors + j] = indices[i * num_anchors + idx] - return ib.get() @@ -803,19 +776,15 @@ def non_max_suppression( sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 ) - indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) out, box_indices = te.extern( [data.shape, score_shape], - [data, sort_tensor, valid_count, indices], + [data, sort_tensor, valid_count], lambda ins, outs: nms_ir( ins[0], ins[1], ins[2], - ins[3], outs[0], outs[1], max_output_size, @@ -828,7 +797,7 @@ def non_max_suppression( return_indices, ), dtype=[data.dtype, "int32"], - in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf], + 
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], name="nms", tag="nms", ) @@ -839,8 +808,8 @@ def non_max_suppression( output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output") return te.extern( [out_shape, valid_box_count_shape], - [box_indices], - lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]), + [box_indices, indices], + lambda ins, outs: rearrange_indices_out_ir(ins[0], ins[1], outs[0], outs[1]), dtype="int32", out_buffers=[output, valid_box_count], name="rearrange_indices_out_gpu", From 03fabdbb03bbeef2914fbac64ae9ac1dc6c8b14e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 18:27:46 +0900 Subject: [PATCH 07/22] parallelize the first kernel --- python/tvm/topi/cuda/nms.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index b11041c3cbbe..38b263af2012 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -51,6 +51,11 @@ def atomic_add(x, y): return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y) + +def ceil_div(a, b): + return tvm.tir.indexdiv(a + b - 1, b) + + def rearrange_indices_out_ir(nms_box_indices, orig_indices, output, valid_box_count): """Compact and remap sorted indices to original indices. 
@@ -511,9 +516,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(num_anchors, max_threads) nthread_by = batch_size + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) i = by base_idx = i * num_anchors * box_data_length with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): @@ -521,16 +532,18 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): nkeep = if_then_else( tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] ) - with ib.for_range(0, nkeep) as j: + j = bx * max_threads + tx + with ib.if_scope(j < nkeep): with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = data[ (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) ] box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] - with ib.for_range(0, num_anchors - nkeep) as j: - with ib.for_range(0, box_data_length) as k: - out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 - box_indices[i * num_anchors + (j + nkeep)] = -1 + with ib.else_scope(): + with ib.if_scope(j < num_anchors): + with ib.for_range(0, box_data_length) as k: + out[(base_idx + j * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + j] = -1 with ib.new_scope(): nthread_by = batch_size From d87cbf856662c1170e90a69f247550553d82bb0b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 18:36:20 +0900 Subject: [PATCH 08/22] merge first and last kernel --- python/tvm/topi/cuda/nms.py | 34 +++++++--------------------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 
38b263af2012..71f885e62de0 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -51,7 +51,6 @@ def atomic_add(x, y): return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y) - def ceil_div(a, b): return tvm.tir.indexdiv(a + b - 1, b) @@ -540,10 +539,17 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ] box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] with ib.else_scope(): + # Indices > nkeep are discarded with ib.if_scope(j < num_anchors): with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = -1.0 box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): + with ib.if_scope(j < valid_count[i]): + with ib.for_range(0, box_data_length) as k: + offset = base_idx + j * box_data_length + k + out[offset] = data[offset] + box_indices[i * num_anchors + j] = j with ib.new_scope(): nthread_by = batch_size @@ -597,32 +603,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(box_indices[i * num_anchors + j] != -1): num_valid_boxes[0] += 1 - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = num_anchors // max_threads + 1 - nthread_by = batch_size - nthread_bz = box_data_length - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - by = te.thread_axis("blockIdx.y") - bz = te.thread_axis("blockIdx.z") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - ib.scope_attr(by, "thread_extent", nthread_by) - ib.scope_attr(bz, "thread_extent", nthread_bz) - tid = bx * max_threads + tx - i = by - j = tid - k = bz - base_idx = i * num_anchors * box_data_length - with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): - pass - with ib.else_scope(): - with ib.if_scope(j < valid_count[i]): - offset_j = j * box_data_length - out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] - box_indices[i * num_anchors + j] = j - return ib.get() From 
8f552bdd6e2e49cbd0088440aba4dfa963368b92 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 18:41:09 +0900 Subject: [PATCH 09/22] remove unnecessary cases --- python/tvm/topi/cuda/nms.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 71f885e62de0..d241cbda997b 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -99,16 +99,8 @@ def rearrange_indices_out_ir(nms_box_indices, orig_indices, output, valid_box_co # TODO(masahi): Use execlusive scan here with ib.for_range(0, num_anchors, name="j") as j: with ib.if_scope(nms_box_indices[i, j] >= 0): - with ib.if_scope(nms_box_indices[i, j] > num_anchors): - output[i, valid_idx[0]] = 0 - valid_idx[0] = valid_idx[0] + 1 - with ib.else_scope(): - output[i, valid_idx[0]] = orig_indices[i, nms_box_indices[i, j]] - valid_idx[0] = valid_idx[0] + 1 - with ib.else_scope(): - with ib.if_scope(nms_box_indices[i, j] < -num_anchors): - output[i, valid_idx[0]] = 0 - valid_idx[0] = valid_idx[0] + 1 + output[i, valid_idx[0]] = orig_indices[i, nms_box_indices[i, j]] + valid_idx[0] = valid_idx[0] + 1 with ib.if_scope(j >= valid_idx[0]): output[i, j] = -1 valid_box_count[i, 0] = valid_idx[0] From e31345e85620e7ac0326488ebadfe7eff5485745 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 18:41:44 +0900 Subject: [PATCH 10/22] fix typo --- python/tvm/topi/cuda/nms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index d241cbda997b..8e20e571c546 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -96,7 +96,7 @@ def rearrange_indices_out_ir(nms_box_indices, orig_indices, output, valid_box_co ib.scope_attr(i, "thread_extent", batch_size) valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local") valid_idx[0] = 0 - # TODO(masahi): Use execlusive scan here + # TODO(masahi): Use 
exclusive scan here with ib.for_range(0, num_anchors, name="j") as j: with ib.if_scope(nms_box_indices[i, j] >= 0): output[i, valid_idx[0]] = orig_indices[i, nms_box_indices[i, j]] From 53caa84ba77661d43b427583427c9cae184c1cc9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 19:01:48 +0900 Subject: [PATCH 11/22] revert pytorch frontend change --- python/tvm/relay/frontend/pytorch.py | 14 ++++++-------- tests/python/frontend/pytorch/test_forward.py | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 94ee9282e4fa..c75bd2dd3c09 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1857,18 +1857,16 @@ def nms(self, inputs, input_types): scores = inputs[1] iou_threshold = inputs[2] - num_boxes = _op.shape_of(scores) - - # TVM NMS assumes score > 0 - scores = scores - _op.min(scores) + _op.const(1.0) # Generate data with shape (1, num_anchors, 5) scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {}) + + # Prepare input data for get_valid_counts data = _op.concatenate([scores, boxes], -1) data = _op.expand_dims(data, 0, 1) - # PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count - indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32") - indices = _op.expand_dims(indices, 0, 1) - ct = num_boxes + # Leverage get_valid_counts to sort the data and clear invalid boxes + ct, data, indices = get_relay_op("get_valid_counts")( + data, score_threshold=-1.0, id_index=-1, score_index=0 + ) # Perform Non-Maximum Suppression, # PyTorch NMS doesn't have parameter top_k and max_output_size diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 18169f31cd00..2dda675c74f5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1675,10 
+1675,10 @@ def _gen_rand_inputs(num_boxes): boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5 boxes[:, 2] += boxes[:, 0] boxes[:, 3] += boxes[:, 1] - scores = torch.from_numpy(np.random.uniform(-1, 1, size=(num_boxes,)).astype(np.float32)) + scores = torch.rand(num_boxes, dtype=torch.float) return boxes, scores - targets = ["llvm", "cuda"] + targets = ["llvm"] # dynamic nms does not work on gpu for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]: in_boxes, in_scores = _gen_rand_inputs(num_boxes) From 878e905430c3891bd1891dc7b94ec0f4f654e0e1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 21:44:14 +0900 Subject: [PATCH 12/22] fuse rearrange step with triangle loop --- python/tvm/topi/cuda/nms.py | 116 ++++++++++++------------------------ 1 file changed, 37 insertions(+), 79 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 8e20e571c546..802e86eb03e2 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -55,59 +55,6 @@ def ceil_div(a, b): return tvm.tir.indexdiv(a + b - 1, b) -def rearrange_indices_out_ir(nms_box_indices, orig_indices, output, valid_box_count): - """Compact and remap sorted indices to original indices. - - Parameters - ---------- - nms_box_indices : tvm.te.Tensor or numpy NDArray - NMS output with return_indices=True - Tensor with shape [batch_size, num_anchors]. - Each row is indices sorted by box scores. - - orig_indices : Buffer - indices in original tensor, with shape [batch_size, num_anchors], - represents the index of box in original data. It could be the third - output out_indices of get_valid_counts. The values in the second - dimension are like the output of arange(num_anchors) if get_valid_counts - is not used before non_max_suppression. - - Returns - ------- - output : tvm.te.Tensor or numpy NDArray - 2-D tensor with shape [batch_size, num_anchors]. 
- - valid_box_count : tvm.te.Tensor or numpy NDArray - Tensor with shape [batch_size, 1], indicates - the valid number of boxes. - """ - batch_size = nms_box_indices.shape[0] - num_anchors = nms_box_indices.shape[1] - - ib = tvm.tir.ir_builder.create() - - nms_box_indices = ib.buffer_ptr(nms_box_indices) - orig_indices = ib.buffer_ptr(orig_indices) - valid_box_count = ib.buffer_ptr(valid_box_count) - output = ib.buffer_ptr(output) - - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", batch_size) - valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local") - valid_idx[0] = 0 - # TODO(masahi): Use exclusive scan here - with ib.for_range(0, num_anchors, name="j") as j: - with ib.if_scope(nms_box_indices[i, j] >= 0): - output[i, valid_idx[0]] = orig_indices[i, nms_box_indices[i, j]] - valid_idx[0] = valid_idx[0] + 1 - with ib.if_scope(j >= valid_idx[0]): - output[i, j] = -1 - valid_box_count[i, 0] = valid_idx[0] - - return ib.get() - - def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index): """Low level IR to identify bounding boxes given a score threshold. @@ -390,8 +337,10 @@ def nms_ir( data, sorted_index, valid_count, + indices, out, box_indices, + num_valid_boxes, max_output_size, iou_threshold, force_suppress, @@ -414,6 +363,13 @@ def nms_ir( valid_count : Buffer Buffer of number of valid output boxes. + indices : Buffer + indices in original tensor, with shape [batch_size, num_anchors], + represents the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the second + dimension are like the output of arange(num_anchors) if get_valid_counts + is not used before non_max_suppression. + out : Buffer Output buffer. 
@@ -493,6 +449,8 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): data = ib.buffer_ptr(data) sorted_index = ib.buffer_ptr(sorted_index) valid_count = ib.buffer_ptr(valid_count) + indices = ib.buffer_ptr(indices) + num_valid_boxes = ib.buffer_ptr(num_valid_boxes) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) @@ -524,18 +482,18 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] ) j = bx * max_threads + tx + with ib.if_scope(j < num_anchors): + box_indices[i * num_anchors + j] = -1 with ib.if_scope(j < nkeep): with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = data[ (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) ] - box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] with ib.else_scope(): # Indices > nkeep are discarded with ib.if_scope(j < num_anchors): with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = -1.0 - box_indices[i * num_anchors + j] = -1 with ib.else_scope(): with ib.if_scope(j < valid_count[i]): with ib.for_range(0, box_data_length) as k: @@ -549,8 +507,10 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ib.scope_attr(by, "thread_extent", nthread_by) i = by base_idx = i * num_anchors * box_data_length - num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") - num_valid_boxes[0] = 0 + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) + num_valid_boxes_local[0] = 0 with ib.if_scope(max_output_size == 0): max_output_size = valid_count[i] @@ -562,7 +522,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): offset_k = k * box_data_length with ib.if_scope( tvm.tir.all( - num_valid_boxes[0] < max_output_size, + num_valid_boxes_local[0] < max_output_size, out[base_idx + offset_k + score_index] > 0, tvm.tir.any(id_index < 0, out[base_idx + offset_k 
+ id_index] >= 0), ) @@ -590,10 +550,16 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): out[base_idx + offset_j + score_index] = -1.0 with ib.if_scope(id_index >= 0): out[base_idx + offset_j + id_index] = -1.0 - box_indices[i * num_anchors + j] = -1 - with ib.if_scope(box_indices[i * num_anchors + j] != -1): - num_valid_boxes[0] += 1 + with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): + if return_indices: + box_indices[i * num_anchors + num_valid_boxes_local[0]] = indices[ + i, sorted_index[i * num_anchors + j] + ] + # box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] + num_valid_boxes_local[0] += 1 + + num_valid_boxes[i] = num_valid_boxes_local[0] return ib.get() @@ -762,16 +728,19 @@ def non_max_suppression( ) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) - out, box_indices = te.extern( - [data.shape, score_shape], - [data, sort_tensor, valid_count], + out, box_indices, num_valid_boxes = te.extern( + [data.shape, score_shape, [batch_size, 1]], + [data, sort_tensor, valid_count, indices], lambda ins, outs: nms_ir( ins[0], ins[1], ins[2], + ins[3], outs[0], outs[1], + outs[2], max_output_size, iou_threshold, force_suppress, @@ -781,24 +750,13 @@ def non_max_suppression( score_index, return_indices, ), - dtype=[data.dtype, "int32"], - in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], + dtype=[data.dtype, "int32", "int32"], + in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf], name="nms", tag="nms", ) + if return_indices: - out_shape = box_indices.shape - valid_box_count_shape = [box_indices.shape[0], 1] - valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count") - output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output") - return te.extern( - [out_shape, valid_box_count_shape], - [box_indices, indices], - lambda ins, 
outs: rearrange_indices_out_ir(ins[0], ins[1], outs[0], outs[1]), - dtype="int32", - out_buffers=[output, valid_box_count], - name="rearrange_indices_out_gpu", - tag="rearrange_indices_out_gpu", - ) + return [box_indices, num_valid_boxes] return out From 0442eceac81504b37c7fd39ed8e435accaaeabd3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 22:35:15 +0900 Subject: [PATCH 13/22] fix max_output_size handling --- python/tvm/topi/cuda/nms.py | 83 +++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 802e86eb03e2..26d302829a14 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -512,52 +512,53 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ) num_valid_boxes_local[0] = 0 - with ib.if_scope(max_output_size == 0): - max_output_size = valid_count[i] - - with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): - # Apply nms - with ib.for_range(0, valid_count[i]) as j: - with ib.for_range(0, j) as k: - offset_k = k * box_data_length + def nms_loop(ib, j): + with ib.for_range(0, j) as k: + offset_k = k * box_data_length + with ib.if_scope( + tvm.tir.all( + out[base_idx + offset_k + score_index] > 0, + tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + ) + ): + offset_j = j * box_data_length with ib.if_scope( tvm.tir.all( - num_valid_boxes_local[0] < max_output_size, + j > k, out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0), + tvm.tir.any( + force_suppress > 0, + id_index < 0, + out[base_idx + offset_k + id_index] + == out[base_idx + offset_j + id_index], + ), ) ): - offset_j = j * box_data_length - with ib.if_scope( - tvm.tir.all( - j > k, - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_j + 
id_index] >= 0), - tvm.tir.any( - force_suppress > 0, - id_index < 0, - out[base_idx + offset_k + id_index] - == out[base_idx + offset_j + id_index], - ), - ) - ): - iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, - ) - with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_j + score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_j + id_index] = -1.0 - - with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): - if return_indices: - box_indices[i * num_anchors + num_valid_boxes_local[0]] = indices[ - i, sorted_index[i * num_anchors + j] - ] - # box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] - num_valid_boxes_local[0] += 1 + iou = calculate_overlap( + out, + base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start, + ) + with ib.if_scope(iou >= iou_threshold): + out[base_idx + offset_j + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_j + id_index] = -1.0 + + with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): + if return_indices: + orig_idx = sorted_index[i * num_anchors + j] + box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] + num_valid_boxes_local[0] += 1 + + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + # Apply nms + with ib.for_range(0, valid_count[i]) as j: + with ib.if_scope(max_output_size > 0): + with ib.if_scope(num_valid_boxes_local[0] < max_output_size): + nms_loop(ib, j) + with ib.else_scope(): + nms_loop(ib, j) num_valid_boxes[i] = num_valid_boxes_local[0] From d6f634ae5a494ab3ec2373da04270d69b63dfdcf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 22:51:21 +0900 Subject: [PATCH 14/22] check if already surpressed --- python/tvm/topi/cuda/nms.py | 57 ++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 
26d302829a14..c71ea4d728a2 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -512,38 +512,34 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ) num_valid_boxes_local[0] = 0 - def nms_loop(ib, j): + def nms_inner_loop(ib, j): + offset_j = j * box_data_length + with ib.for_range(0, j) as k: offset_k = k * box_data_length + with ib.if_scope( tvm.tir.all( + out[base_idx + offset_j + score_index] > -1.0, # if already surpressed out[base_idx + offset_k + score_index] > 0, tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + tvm.tir.any( + force_suppress > 0, + id_index < 0, + out[base_idx + offset_k + id_index] + == out[base_idx + offset_j + id_index], + ), ) ): - offset_j = j * box_data_length - with ib.if_scope( - tvm.tir.all( - j > k, - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0), - tvm.tir.any( - force_suppress > 0, - id_index < 0, - out[base_idx + offset_k + id_index] - == out[base_idx + offset_j + id_index], - ), - ) - ): - iou = calculate_overlap( - out, - base_idx + offset_j + coord_start, - base_idx + offset_k + coord_start, - ) - with ib.if_scope(iou >= iou_threshold): - out[base_idx + offset_j + score_index] = -1.0 - with ib.if_scope(id_index >= 0): - out[base_idx + offset_j + id_index] = -1.0 + iou = calculate_overlap( + out, + base_idx + offset_j + coord_start, + base_idx + offset_k + coord_start, + ) + with ib.if_scope(iou >= iou_threshold): + out[base_idx + offset_j + score_index] = -1.0 + with ib.if_scope(id_index >= 0): + out[base_idx + offset_j + id_index] = -1.0 with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): if return_indices: @@ -554,11 +550,14 @@ def nms_loop(ib, j): with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms with ib.for_range(0, valid_count[i]) as j: - with ib.if_scope(max_output_size > 0): - with ib.if_scope(num_valid_boxes_local[0] < max_output_size): - nms_loop(ib, 
j) - with ib.else_scope(): - nms_loop(ib, j) + with ib.if_scope( + tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0) + ): + with ib.if_scope(max_output_size > 0): + with ib.if_scope(num_valid_boxes_local[0] < max_output_size): + nms_inner_loop(ib, j) + with ib.else_scope(): + nms_inner_loop(ib, j) num_valid_boxes[i] = num_valid_boxes_local[0] From fc3052669ecd4b21c5dea72d3fe79e1d7a01c1a8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 07:30:06 +0900 Subject: [PATCH 15/22] fix topi vision test by wrapping tir const around int argument --- python/tvm/topi/cuda/nms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index c71ea4d728a2..8a372f2cabd0 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -547,6 +547,9 @@ def nms_inner_loop(ib, j): box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] num_valid_boxes_local[0] += 1 + if isinstance(max_output_size, int): + max_output_size = tvm.tir.const(max_output_size) + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms with ib.for_range(0, valid_count[i]) as j: From 1c6fb6ed2df3f1a91d2f61f9391f95a6b155d906 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 10:08:32 +0900 Subject: [PATCH 16/22] fix for num anchors = 0 case --- python/tvm/topi/cuda/nms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 8a372f2cabd0..638a6b63946b 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -466,7 +466,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) + nthread_bx = tvm.tir.max(1, ceil_div(num_anchors, max_threads)) nthread_by = batch_size tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") From 
51a7926ccd73c5e0a88f57dfeee6091665251ffc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 10:51:18 +0900 Subject: [PATCH 17/22] fix missing zero init of num valid boxes when the input is empty --- python/tvm/topi/cuda/nms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 638a6b63946b..1fbabc1c4c77 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -563,6 +563,8 @@ def nms_inner_loop(ib, j): nms_inner_loop(ib, j) num_valid_boxes[i] = num_valid_boxes_local[0] + with ib.else_scope(): + num_valid_boxes[i] = 0 return ib.get() From 622a87614ea0d517a97c41a9b715c0fd2fb206bf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 11:05:31 +0900 Subject: [PATCH 18/22] add some comments and missing doc --- python/tvm/topi/cuda/nms.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 1fbabc1c4c77..c4275732ec71 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -371,7 +371,15 @@ def nms_ir( is not used before non_max_suppression. out : Buffer - Output buffer. + Output buffer, to be filled with sorted boxes. + + box_indices : Buffer + A indices tensor mapping sorted indices to original indices + This is the first output of NMS when return_indices=True. + + num_valid_boxes : Buffer + Record the number of boxes that have survived IOU tests. + This is the second output of NMS when return_indices=True. max_output_size : int Max number of output valid boxes for each instance. 
@@ -466,6 +474,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_tx = max_threads + # num_anchors can be zero nthread_bx = tvm.tir.max(1, ceil_div(num_anchors, max_threads)) nthread_by = batch_size tx = te.thread_axis("threadIdx.x") @@ -485,6 +494,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(j < num_anchors): box_indices[i * num_anchors + j] = -1 with ib.if_scope(j < nkeep): + # Fill in out with sorted boxes with ib.for_range(0, box_data_length) as k: out[(base_idx + j * box_data_length + k)] = data[ (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) @@ -541,7 +551,9 @@ def nms_inner_loop(ib, j): with ib.if_scope(id_index >= 0): out[base_idx + offset_j + id_index] = -1.0 + # Does the box j has survived IOU tests? with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): + # When return_indices is False, no need to populate box_indices if return_indices: orig_idx = sorted_index[i * num_anchors + j] box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] @@ -557,12 +569,14 @@ def nms_inner_loop(ib, j): tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0) ): with ib.if_scope(max_output_size > 0): + # No need to do more iteration if we alread reach max_output_size boxes with ib.if_scope(num_valid_boxes_local[0] < max_output_size): nms_inner_loop(ib, j) with ib.else_scope(): nms_inner_loop(ib, j) num_valid_boxes[i] = num_valid_boxes_local[0] + with ib.else_scope(): num_valid_boxes[i] = 0 From 6311dcb20738eb82cd33e9b8713e1ecbbe5a6a2f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 11:13:07 +0900 Subject: [PATCH 19/22] typo fix --- python/tvm/topi/cuda/nms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index c4275732ec71..2bc2145a4cb1 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -569,7 +569,7 @@ def 
nms_inner_loop(ib, j): tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0) ): with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we alread reach max_output_size boxes + # No need to do more iteration if we already reach max_output_size boxes with ib.if_scope(num_valid_boxes_local[0] < max_output_size): nms_inner_loop(ib, j) with ib.else_scope(): From 591a4f4e295aad3eb39367127ff014a2ba4fb231 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 13:09:44 +0900 Subject: [PATCH 20/22] add a guard against zero dim grid / thread block inside ir_builder --- python/tvm/tir/ir_builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 75c5c2921ff4..6dcc8580a221 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -21,6 +21,7 @@ from . import stmt as _stmt from . import expr as _expr +from . import op class WithScope(object): @@ -200,6 +201,9 @@ def scope_attr(self, node, attr_key, value): node = _expr.StringImm(node) if isinstance(value, string_types): value = _expr.StringImm(value) + # thread_extent could be zero for dynamic workloads + if attr_key == "thread_extent": + value = op.max(1, value) self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x)) def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): From 9b08052cc37cbba40b92ef0db3f55bc0f43bb546 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 13:10:41 +0900 Subject: [PATCH 21/22] typo fix --- python/tvm/topi/cuda/nms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2bc2145a4cb1..08290b6fbca8 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -475,7 +475,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_tx = max_threads # num_anchors can be zero -
nthread_bx = tvm.tir.max(1, ceil_div(num_anchors, max_threads)) + nthread_bx = ceil_div(num_anchors, max_threads) nthread_by = batch_size tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -551,7 +551,7 @@ def nms_inner_loop(ib, j): with ib.if_scope(id_index >= 0): out[base_idx + offset_j + id_index] = -1.0 - # Does the box j has survived IOU tests? + # Has the box j survived IOU tests? with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0): # When return_indices is False, no need to populate box_indices if return_indices: From 878c9cb04015295931bf13c919d378f0c7e19f92 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 21 Dec 2020 17:40:46 +0900 Subject: [PATCH 22/22] trigger CI --- python/tvm/topi/cuda/nms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 08290b6fbca8..cea287edd62e 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -474,7 +474,6 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.new_scope(): nthread_tx = max_threads - # num_anchors can be zero nthread_bx = ceil_div(num_anchors, max_threads) nthread_by = batch_size tx = te.thread_axis("threadIdx.x")