diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 2b905f5bd04b..ca2c4a2b837d 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -73,14 +73,12 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode { - double score_threshold; + Optional score_threshold; int id_index; int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { - TVM_ATTR_FIELD(score_threshold) - .set_default(0.0) - .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(score_threshold).describe("Lower limit of score for valid bounding boxes."); TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id."); TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes."); } @@ -89,7 +87,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { /*! \brief Attributes used in non_maximum_suppression operator */ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { Optional max_output_size; - double iou_threshold; + Optional iou_threshold; bool force_suppress; int top_k; int coord_start; @@ -100,9 +98,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode= 0): + with ib.if_scope(data[i, j] > num_anchors): + output[i, valid_idx[0]] = 0 + valid_idx[0] = valid_idx[0] + 1 + with ib.else_scope(): + output[i, valid_idx[0]] = data[i, j] + valid_idx[0] = valid_idx[0] + 1 + with ib.else_scope(): + with ib.if_scope(data[i, j] < -num_anchors): + output[i, valid_idx[0]] = 0 + valid_idx[0] = valid_idx[0] + 1 + with ib.if_scope(j >= valid_idx[0]): + output[i, j] = -1 + valid_box_count[i, 0] = valid_idx[0] - one_count = tvm.tir.const(1, dtype="int32") - atomic_add_return = ib.allocate( - valid_box_count.dtype, (batch_size,), name="atomic_add_return", scope="local" - ) + return ib.get() + + +def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index): + """Low level IR to identify bounding boxes given a score threshold. + + Parameters + ---------- + data : Buffer + Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length]. + + score_threshold : Buffer or float32 + Lower limit of score for valid bounding boxes. + + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + + Returns + ------- + valid_boxes: Buffer + 2D Buffer indicating valid boxes with shape [batch_size, num_anchors]. 
+ + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + + valid_boxes = ib.buffer_ptr(valid_boxes) + if isinstance(score_threshold, float): + score_threshold = tvm.tir.FloatImm("float32", score_threshold) + id_index = tvm.tir.IntImm("int32", id_index) + score_index = tvm.tir.IntImm("int32", score_index) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - len_inner_for = (batch_size * num_anchors) // nthread_tx + 2 - - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - with ib.for_range(0, len_inner_for, name="i") as i: - idx = tx * len_inner_for + i - batch_idx = idxd(idx, num_anchors) - with ib.if_scope(idx < batch_size): - valid_box_count[idx] = 0 - with ib.if_scope(idx < batch_size * num_anchors): - with ib.if_scope(data[idx] >= 0): - atomic_add_return[batch_idx] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]), - one_count, + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + nthread_by = batch_size + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + tid = bx * max_threads + tx + + with ib.if_scope(tid < num_anchors): + i = by + j = tid + score = data[(i * num_anchors + j) * elem_length + score_index] + with ib.if_scope( + tvm.tir.all( + score > score_threshold, + tvm.tir.any( + id_index < 0, data[(i * num_anchors + j) * elem_length + id_index] >= 0 + ), ) - out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = data[idx] - with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] < -num_anchors)): - atomic_add_return[batch_idx] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_box_count[batch_idx]), - one_count, - ) - out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0 + ): + valid_boxes[i * num_anchors + j] = 1 + with ib.else_scope(): + valid_boxes[i * num_anchors + j] = 0 + return ib.get() + + +def get_valid_indices_ir(valid_boxes, valid_count, valid_indices): + """Low level IR to get the ouput indices of valid boxes + and the count of valid boxes + + Parameters + ---------- + valid_boxes: Buffer + 2D Buffer indicating valid boxes with shape [batch_size, num_anchors]. + + Returns + ------- + valid_count: Buffer + 1D Buffer of number of valid boxes per batch [batch_size]. + + valid_indices: Buffer + 2D Buffer indicating output sorted indcies of valid boxes [batch_size, num_anchors]. 
+ """ + batch_size = valid_boxes.shape[0] + num_anchors = valid_boxes.shape[1] - with ib.if_scope(idxm(idx, num_anchors) >= valid_box_count[batch_idx]): - out[idx] = -1 + ib = tvm.tir.ir_builder.create() + + valid_boxes = ib.buffer_ptr(valid_boxes) + + valid_count = ib.buffer_ptr(valid_count) + valid_indices = ib.buffer_ptr(valid_indices) + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = batch_size // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + # TODO(mbrookhart): Parallelize the sum and cumsum here + current_index = ib.allocate("int32", (1,), name="current_index", scope="local") + with ib.if_scope(tid < batch_size): + current_index[0] = 0 + valid_count[tid] = 0 + with ib.for_range(0, num_anchors) as j: + idx = tid * num_anchors + j + valid_count[tid] = valid_count[tid] + valid_boxes[idx] + with ib.if_scope(valid_boxes[idx] == 1): + valid_indices[idx] = current_index[0] + current_index[0] = current_index[0] + 1 + with ib.else_scope(): + valid_indices[idx] = -1 return ib.get() -def get_valid_counts_ir( - data, valid_count, out, out_indices, score_threshold, id_index, score_index -): +def get_valid_counts_ir(data, valid_indices, out, out_indices): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the top of input data. @@ -126,25 +246,16 @@ def get_valid_counts_ir( data : Buffer Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length]. - valid_count : Buffer - 1D buffer for valid number of boxes with shape [batch_size, ]. - - flag : Buffer + valid_indices: Buffer 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - score_threshold : float32 - Lower limit of score for valid bounding boxes. - - id_index : optional, int - index of the class categories, -1 to disable. - - score_index: optional, int - Index of the scores/confidence of boxes. - Returns ------- - stmt : Stmt - The result IR statement. 
+ out : Buffer + Sorted valid boxes + + out_indices : Buffer + Incidices of valid boxes in original data """ batch_size = data.shape[0] num_anchors = data.shape[1] @@ -154,50 +265,51 @@ def get_valid_counts_ir( data = ib.buffer_ptr(data) - valid_count = ib.buffer_ptr(valid_count) + valid_indices = ib.buffer_ptr(valid_indices) out = ib.buffer_ptr(out) out_indices = ib.buffer_ptr(out_indices) - atomic_add_return = ib.allocate( - valid_count.dtype, (1,), name="atomic_add_return", scope="local" - ) - one_count = tvm.tir.const(1, dtype=valid_count.dtype) one = tvm.tir.const(1, dtype=out.dtype) - score_threshold = tvm.tir.FloatImm("float32", score_threshold) - id_index = tvm.tir.IntImm("int32", id_index) - score_index = tvm.tir.IntImm("int32", score_index) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads - nthread_bx = batch_size * num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - idxd = tvm.tir.indexdiv - - # initialize valid_count - with ib.if_scope(tid < batch_size): - valid_count[tid] = 0 - with ib.if_scope(tid < batch_size * num_anchors): - i = idxd(tid, num_anchors) - with ib.if_scope( - tvm.tir.all( - data[tid * elem_length + score_index] > score_threshold, - tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0), - ) - ): - atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", valid_count[i]), one_count - ) - with ib.for_range(0, elem_length) as k: - out[tid * elem_length + k] = data[tid * elem_length + k] - out_indices[tid + k] = tid + k - with ib.else_scope(): - with ib.for_range(0, elem_length) as k: - out[tid * elem_length + k] = -one - out_indices[tid + k] = -one_count - + nthread_bx = num_anchors // max_threads + 1 + nthread_by = batch_size + nthread_bz = elem_length + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + tid = bx * max_threads + tx + with ib.if_scope(tid < num_anchors): + i = by + j = tid + k = bz + out[(i * num_anchors + j) * elem_length + k] = -one + out_indices[i * num_anchors + j] = -1 + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + tid = bx * max_threads + tx + with ib.if_scope(tid < num_anchors): + i = by + j = tid + k = bz + with ib.if_scope(valid_indices[i, tid] >= 0): + out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[ + (i * num_anchors + j) * elem_length + k + ] + out_indices[i * num_anchors + valid_indices[i, tid]] = j return ib.get() @@ -210,7 +322,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): data : tvm.te.Tensor Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length]. 
- score_threshold : optional, float + score_threshold : optional, tvm.te.Tensor or float Lower limit of score for valid bounding boxes. id_index : optional, int @@ -230,23 +342,51 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): batch_size = data.shape[0] num_anchors = data.shape[1] data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + valid_boxes_buf = tvm.tir.decl_buffer( + (batch_size, num_anchors), "int32", "valid_boxes_buf", data_alignment=8 + ) + valid_boxes = te.extern( + [(batch_size, num_anchors)], + [data], + lambda ins, outs: get_valid_boxes_ir( + ins[0], outs[0], score_threshold, id_index, score_index + ), + dtype=["int32"], + in_buffers=[data_buf], + out_buffers=[valid_boxes_buf], + name="get_valid_boxes", + tag="get_valid_boxes_gpu", + ) + + valid_indices_buf = tvm.tir.decl_buffer( + (batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8 + ) valid_count_buf = tvm.tir.decl_buffer( (batch_size,), "int32", "valid_count_buf", data_alignment=8 ) + valid_count, valid_indices = te.extern( + [(batch_size,), (batch_size, num_anchors)], + [valid_boxes], + lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]), + dtype=["int32"], + in_buffers=[valid_boxes_buf], + out_buffers=[valid_count_buf, valid_indices_buf], + name="get_valid_indices", + tag="get_valid_indices_gpu", + ) + out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) out_indices_buf = tvm.tir.decl_buffer( (batch_size, num_anchors), "int32", "out_buf", data_alignment=8 ) - valid_count, out, out_indices = te.extern( - [(batch_size,), data.shape, (batch_size, num_anchors)], - [data], - lambda ins, outs: get_valid_counts_ir( - ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index - ), + out, out_indices = te.extern( + [data.shape, (batch_size, num_anchors)], + [data, valid_indices], + lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]), dtype=["int32", data.dtype], - in_buffers=[data_buf], - out_buffers=[valid_count_buf, out_buf, out_indices_buf], + in_buffers=[data_buf, valid_indices_buf], + out_buffers=[out_buf, out_indices_buf], name="get_valid_counts", tag="get_valid_counts_gpu", ) @@ -277,12 +417,19 @@ def nms_ir( data : Buffer Buffer of output boxes with class and score. - sort_index : Buffer + sorted_index : Buffer Buffer of output box indexes sorted by score. valid_count : Buffer Buffer of number of valid output boxes. + indices : Buffer + indices in original tensor, with shape [batch_size, num_anchors], + represents the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the second + dimension are like the output of arange(num_anchors) if get_valid_counts + is not used before non_max_suppression. + out : Buffer Output buffer. @@ -308,33 +455,50 @@ def nms_ir( score_index : optional, int Index of the scores/confidence of boxes. + return_indices : boolean + Whether to return box indices in input data. + Returns ------- stmt : Stmt The result IR statement. 
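For reference, here is a rough NumPy mock of the three-stage pipeline the rewritten CUDA kernels implement: get_valid_boxes_ir flags each box, get_valid_indices_ir turns the flags into a per-batch count plus an exclusive prefix sum, and get_valid_counts_ir scatters the valid boxes to the top of the output. The name `ref_get_valid_counts` is invented here purely for illustration; this is a sketch of the intended semantics, not part of the patch.

```python
import numpy as np

def ref_get_valid_counts(data, score_threshold, id_index=0, score_index=1):
    """NumPy mock of the flag -> scan -> compact pipeline (illustrative only)."""
    batch, num_anchors, _ = data.shape
    # Stage 1 (get_valid_boxes_ir): per-box validity flags.
    scores = data[:, :, score_index]
    ids_ok = True if id_index < 0 else data[:, :, id_index] >= 0
    flags = np.logical_and(scores > score_threshold, ids_ok).astype("int32")
    # Stage 2 (get_valid_indices_ir): per-batch count and exclusive prefix sum.
    valid_count = flags.sum(axis=1).astype("int32")
    positions = np.cumsum(flags, axis=1) - flags          # exclusive scan
    valid_indices = np.where(flags == 1, positions, -1)
    # Stage 3 (get_valid_counts_ir): move valid boxes to the top, pad with -1.
    out = np.full_like(data, -1.0)
    out_indices = np.full((batch, num_anchors), -1, dtype="int32")
    for i in range(batch):
        for j in range(num_anchors):
            if valid_indices[i, j] >= 0:
                out[i, valid_indices[i, j]] = data[i, j]
                out_indices[i, valid_indices[i, j]] = j
    return valid_count, out, out_indices
```

The exclusive prefix sum is the step the TODO in get_valid_indices_ir refers to when it mentions parallelizing the sum/cumsum.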
""" - def calculate_overlap(out_tensor, box_a_idx, box_b_idx): - """Calculate overlap of two boxes.""" - w = tvm.te.max( - 0.0, - tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) - - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx]), + def get_boundaries(output, box_idx): + l = tvm.te.min( + output[box_idx], + output[box_idx + 2], + ) + t = tvm.te.min( + output[box_idx + 1], + output[box_idx + 3], ) - h = tvm.te.max( - 0.0, - tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) - - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]), + r = tvm.te.max( + output[box_idx], + output[box_idx + 2], ) - i = w * h - u = ( - (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) - * (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) - + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) - * (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - - i + b = tvm.te.max( + output[box_idx + 1], + output[box_idx + 3], ) - return tvm.tir.Select(u <= 0.0, 0.0, i / u) + return l, t, r, b + + def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes.""" + a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) + b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) + + # Overlapping width and height + w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) + h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) + + # Overlapping area + area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area + u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + return tvm.tir.Select(u <= 0.0, 0.0, area / u) batch_size = data.shape[0] num_anchors = data.shape[1] @@ -345,60 +509,64 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): data = ib.buffer_ptr(data) sorted_index = ib.buffer_ptr(sorted_index) valid_count = ib.buffer_ptr(valid_count) + indices = ib.buffer_ptr(indices) out = ib.buffer_ptr(out) box_indices = ib.buffer_ptr(box_indices) - indices = ib.buffer_ptr(indices) - num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = num_anchors // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - j = bx * max_threads + tx - iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) + if isinstance(iou_threshold, float): + iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) top_k = tvm.tir.IntImm("int32", top_k) coord_start = tvm.tir.IntImm("int32", coord_start) id_index = tvm.tir.IntImm("int32", id_index) score_index = tvm.tir.IntImm("int32", score_index) force_suppress = tvm.tir.IntImm("int32", 1 if force_suppress else 0) - with ib.for_range(0, batch_size, for_type="unroll") as i: + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.new_scope(): + nthread_by = batch_size + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + i = by base_idx = i * num_anchors * box_data_length with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Reorder output nkeep = if_then_else( tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i] ) - with ib.if_scope(j < nkeep): + with ib.for_range(0, nkeep) as j: with ib.for_range(0, box_data_length) as k: out[(base_idx + j * 
box_data_length + k)] = data[ (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k) ] box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j] with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])): - with ib.if_scope(j < valid_count[i] - nkeep): + with ib.for_range(0, valid_count[i] - nkeep) as j: with ib.for_range(0, box_data_length) as k: out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 box_indices[i * num_anchors + (j + nkeep)] = -1 + with ib.new_scope(): + nthread_by = batch_size + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + i = by + base_idx = i * num_anchors * box_data_length + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): # Apply nms - with ib.for_range(0, valid_count[i]) as k: - offset_k = k * box_data_length - with ib.if_scope( - tvm.tir.all( - out[base_idx + offset_k + score_index] > 0, - tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), - ) - ): - with ib.if_scope(j < valid_count[i]): + with ib.for_range(0, valid_count[i]) as j: + with ib.for_range(0, j) as k: + offset_k = k * box_data_length + with ib.if_scope( + tvm.tir.all( + out[base_idx + offset_k + score_index] > 0, + tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0), + ) + ): offset_j = j * box_data_length with ib.if_scope( tvm.tir.all( j > k, - out[base_idx + offset_j + score_index] > 0, + out[base_idx + offset_k + score_index] > 0, tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0), tvm.tir.any( force_suppress > 0, @@ -418,21 +586,47 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(id_index >= 0): out[base_idx + offset_j + id_index] = -1.0 box_indices[i * num_anchors + j] = -1 + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + nthread_by = batch_size + nthread_bz = box_data_length + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + tid = bx * max_threads + tx + i = by + j = tid + k = bz + base_idx = i * num_anchors * box_data_length + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + pass with ib.else_scope(): with ib.if_scope(j < valid_count[i]): offset_j = j * box_data_length - with ib.for_range(0, box_data_length) as k: - out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] + out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] box_indices[i * num_anchors + j] = j + + with ib.new_scope(): + num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", batch_size) + i = bx + base_idx = i * num_anchors * box_data_length # Set invalid entry to be -1 - with ib.if_scope(j < num_anchors - valid_count[i]): + with ib.for_range(0, num_anchors - valid_count[i]) as j: with ib.for_range(0, box_data_length) as k: out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 box_indices[i * num_anchors + j + valid_count[i]] = -1 # Only return max_output_size number of valid boxes num_valid_boxes[0] = 0 with ib.if_scope(max_output_size > 0): - with ib.if_scope(j < valid_count[i]): + with ib.for_range(0, valid_count[i]) as j: offset_j = j * box_data_length with 
ib.if_scope(out[base_idx + offset_j] >= 0): with ib.if_scope(num_valid_boxes[0] == max_output_size): @@ -442,11 +636,20 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.else_scope(): num_valid_boxes[0] += 1 - if return_indices: - with ib.if_scope(j < valid_count[i]): - box_idx = box_indices[i * num_anchors + j] - with ib.if_scope(box_idx >= 0): - box_indices[i * num_anchors + j] = indices[i * num_anchors + box_idx] + if return_indices: + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = batch_size // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + i = bx * max_threads + tx + with ib.if_scope(i < batch_size): + with ib.for_range(0, valid_count[i]) as j: + idx = box_indices[i * num_anchors + j] + with ib.if_scope(idx >= 0): + box_indices[i * num_anchors + j] = indices[i * num_anchors + idx] return ib.get() @@ -486,11 +689,11 @@ def non_max_suppression( second dimension are like the output of arange(num_anchors) if get_valid_counts is not used before non_max_suppression. - max_output_size : optional, int + max_output_size : optional, tvm.te.Tensor or int Max number of output valid boxes for each instance. By default all valid boxes are returned. - iou_threshold : optional, float + iou_threshold : optional, tvm.te.Tensor or float Non-maximum suppression threshold. force_suppress : optional, boolean @@ -570,6 +773,8 @@ def non_max_suppression( sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8 ) + indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) @@ -597,19 +802,19 @@ def non_max_suppression( name="nms", tag="nms", ) - if return_indices: - out_buf = tvm.tir.decl_buffer( - box_indices.shape, box_indices.dtype, "out_buf", data_alignment=8 - ) + out_shape = box_indices.shape + valid_box_count_shape = [box_indices.shape[0], 1] + valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count") + output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output") return te.extern( - [box_indices.shape, (batch_size, 1)], + [out_shape, valid_box_count_shape], [box_indices], lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]), - dtype=[box_indices.dtype, valid_count.dtype], - in_buffers=[out_buf], - name="rearrange_indices_out", - tag="rearrange_indices_out", + dtype="int32", + out_buffers=[output, valid_box_count], + name="rearrange_indices_out_gpu", + tag="rearrange_indices_out_gpu", ) return out diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 0094ef1adf11..329f0fb897e5 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -104,9 +104,9 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): nthread_bx = shape[axis] // max_threads + 1 tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("vthread") + bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "virtual_thread", nthread_bx) + ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * nthread_tx + tx temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local") if indices_out is not None: @@ -202,9 +202,9 @@ def sort_nms_ir(data, valid_count, output, 
axis, is_ascend): nthread_tx = max_threads nthread_bx = size // max_threads + 1 tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("vthread") + bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "virtual_thread", nthread_bx) + ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * nthread_tx + tx temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index b076fde9ac6e..035d19f25ec7 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -133,7 +133,7 @@ def hybrid_get_valid_counts( Input data. 3-D tensor with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5]. - score_threshold : tvm.tir.const + score_threshold : tvm.te.Tensor Lower limit of score for valid bounding boxes. id_index : tvm.tir.const @@ -213,12 +213,13 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): out_indices: tvm.te.Tensor or numpy NDArray Related index in input data. """ - score_threshold_const = tvm.tir.const(score_threshold, data.dtype) + if isinstance(score_threshold, float): + score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype) id_index_const = tvm.tir.const(id_index, "int32") score_index_const = tvm.tir.const(score_index, "int32") return hybrid_get_valid_counts( data, - score_threshold_const, + score_threshold, id_index_const, score_index_const, tvm.tir.const(1, data.dtype), @@ -281,7 +282,7 @@ def hybrid_nms( Max number of output valid boxes for each instance. Return all valid boxes if max_output_size < 0. - iou_threshold : tvm.tir.const + iou_threshold : tvm.te.Tensor Overlapping(IoU) threshold to suppress object with smaller score. force_suppress : tvm.tir.const @@ -494,7 +495,7 @@ def non_max_suppression( Max number of output valid boxes for each instance. Return all valid boxes if the value of max_output_size is less than 0. - iou_threshold : optional, float + iou_threshold : optional, float or tvm.te.Tensor Non-maximum suppression threshold. 
force_suppress : optional, boolean @@ -554,6 +555,8 @@ def non_max_suppression( num_anchors = data.shape[1] if isinstance(max_output_size, int): max_output_size = tvm.tir.const(max_output_size, dtype="int32") + if isinstance(iou_threshold, float): + iou_threshold = tvm.tir.const(iou_threshold, dtype=data.dtype) score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis]) @@ -567,7 +570,7 @@ def non_max_suppression( batch_size, num_anchors, max_output_size, - tvm.tir.const(iou_threshold, dtype=data.dtype), + iou_threshold, tvm.tir.const(force_suppress, dtype="bool"), tvm.tir.const(top_k, dtype="int32"), tvm.tir.const(coord_start, dtype="int32"), @@ -577,6 +580,7 @@ def non_max_suppression( zero=tvm.tir.const(0, dtype=data.dtype), one=tvm.tir.const(1, dtype=data.dtype), ) + if return_indices: return hybrid_rearrange_indices_out( box_indices, diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index f652644afa3c..bed2510cdf3c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1070,6 +1070,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::FuseOps()); pass_seqs.push_back(transform::ToANormalForm()); + pass_seqs.push_back(transform::InferType()); pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::InlinePrimitives()); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index f21d0967701a..8e9cc625063b 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -111,6 +111,15 @@ class LambdaLifter : public ExprMutator { } captured_vars.push_back(var); } + + Array typed_captured_vars; + Map rebinding_map; + for (auto free_var : captured_vars) { + auto var = Var(free_var->name_hint(), free_var->checked_type()); + typed_captured_vars.push_back(var); + rebinding_map.Set(free_var, var); + } + if (recursive) { if (!captured_vars.empty()) { Array fvs; @@ -122,6 +131,7 @@ class LambdaLifter : public ExprMutator { lambda_map_.emplace(letrec_.back(), global); } } + auto body = Downcast(ExprMutator::VisitExpr_(func_node)); // When performing this optimization there are two cases. @@ -150,7 +160,25 @@ class LambdaLifter : public ExprMutator { if (captured_vars.size() == 0 && free_type_vars.size() == 0) { lifted_func = Function(body->params, body->body, body->ret_type, body->type_params); } else { - lifted_func = Function(captured_vars, body, func->func_type_annotation(), free_type_vars); + // When a closure is locally bound in a program, we have its full type information + // avalible to us. + // + // If we lift the closure out of its bound context it may have free variables which + // do not have type annotations. + // + // In this case we first type check the program assigning a type to all sub-expressions. + // + // We then change the un-annotated free variables into annotated free variables, use + // bind to go from unannotated free variables -> annotated free variables and then + // construct the "closure" function with fully annotated arguments, no longer relying + // on type inference. 
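The effect of the rebinding described in the comment above is easiest to see from Python. A minimal sketch, using the same Relay API the updated tests below rely on (InferType must now run before LambdaLift so that captured free variables carry checked types):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
y = relay.var("y", shape=(2,), dtype="float32")

# `inner` captures `x` as a free variable; after lifting, `x` becomes an
# explicit, type-annotated parameter of the lifted global function.
inner = relay.Function([y], relay.add(x, y))
outer = relay.Function([x], inner)

mod = tvm.IRModule.from_expr(outer)
mod = relay.transform.InferType()(mod)   # assign checked types first
mod = relay.transform.LambdaLift()(mod)  # captured vars are rebound with those types
print(mod)
```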
+ auto before = Downcast(body)->params.size(); + auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, + func->type_params, func->attrs, func->span); + auto after = Downcast(rebound_body)->params.size(); + CHECK_EQ(before, after); + lifted_func = + Function(typed_captured_vars, rebound_body, func->func_type_annotation(), free_type_vars); lifted_func = MarkClosure(lifted_func); } @@ -164,6 +192,7 @@ class LambdaLifter : public ExprMutator { global = module_->GetGlobalVar(name); } else { // Add the lifted function to the module. + std::cout << AsText(lifted_func) << std::endl; module_->Add(global, lifted_func); } diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 4173d57a84de..34aaf4689a59 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -44,21 +44,29 @@ template bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] - ICHECK_EQ(types.size(), 2); + ICHECK_EQ(types.size(), 2) << "the arity of concatenate is 2, not " << types.size(); /* If we receive a tuple we can continue, if we receive * anything but an incomplete type we should signal an * error. */ const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { - throw Error( - ErrorBuilder() << "concatenate requires a tuple of tensors as the first argument, found " - << PrettyPrint(types[0])); + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "concatenate requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0])); + return false; } else if (types[0].as() != nullptr) { return false; } const auto* param = attrs.as(); + if (param == nullptr) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "the call attributes are not defined"); + return false; + } + if (tensor_tuple->fields[0].as()) { return false; } diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 76fdf2829ed0..9316fecddca7 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -31,8 +31,9 @@ TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs); bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 2); + ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); + if (data == nullptr) return false; const auto& dshape = data->shape; ICHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; @@ -44,17 +45,16 @@ bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& att fields.push_back(TensorType(oshape_indices, DataType::Int(32))); // assign output type - reporter->Assign(types[1], TupleType(Array(fields))); + reporter->Assign(types[2], TupleType(Array(fields))); return true; } -Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) { +Expr MakeGetValidCounts(Expr data, Expr score_threshold, int id_index, int score_index) { auto attrs = make_object(); - attrs->score_threshold = score_threshold; attrs->id_index = id_index; attrs->score_index = score_index; static const Op& op = Op::Get("vision.get_valid_counts"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(op, {data, score_threshold}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts); @@ -64,8 +64,9 @@ RELAY_REGISTER_OP("vision.get_valid_counts") a score threshold. 
Also moves valid boxes to the top of input data. )doc" TVM_ADD_FILELINE) - .set_num_inputs(1) + .set_num_inputs(2) .add_argument("data", "Tensor", "Input data.") + .add_argument("score_threshold", "Tensor", "Minimum Score.") .set_support_level(5) .add_type_rel("GetValidCount", GetValidCountRel); @@ -73,9 +74,11 @@ TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 5); + ICHECK_EQ(types.size(), 6); const auto* data = types[0].as(); + if (data == nullptr) return false; const auto* valid_count = types[1].as(); + if (valid_count == nullptr) return false; const NonMaximumSuppressionAttrs* param = attrs.as(); const auto& dshape = data->shape; const auto& vshape = valid_count->shape; @@ -90,18 +93,17 @@ bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, fields.push_back(TensorType(oshape, DataType::Int(32))); std::vector countshape({dshape[0], 1}); fields.push_back(TensorType(countshape, DataType::Int(32))); - reporter->Assign(types[4], TupleType(Array(fields))); + reporter->Assign(types[5], TupleType(Array(fields))); } else { - reporter->Assign(types[4], TensorType(dshape, data->dtype)); + reporter->Assign(types[5], TensorType(dshape, data->dtype)); } return true; } -Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, double iou_threshold, +Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, Expr iou_threshold, bool force_suppress, int top_k, int coord_start, int score_index, int id_index, bool return_indices, bool invalid_to_bottom) { auto attrs = make_object(); - attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; attrs->top_k = top_k; attrs->coord_start = coord_start; @@ -110,7 +112,7 @@ Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, do attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; static const Op& op = Op::Get("vision.non_max_suppression"); - return Call(op, {data, valid_count, indices, max_output_size}, Attrs(attrs), {}); + return Call(op, {data, valid_count, indices, max_output_size, iou_threshold}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS); @@ -121,11 +123,12 @@ be in the format of [class_id, score, left, top, right, bottom] or [score, left, top, right, bottom]. Set id_index to be -1 to ignore class_id axis. 
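As a plain-Python cross-check of the overlap and suppression logic that `calculate_overlap` and `nms_ir` implement earlier in this diff (the function names below are invented for illustration; the real kernels operate on TIR buffers, not Python lists):

```python
def iou(a, b):
    # Normalize corners as get_boundaries does, so boxes with swapped corners
    # still yield a non-negative overlap. a and b are [l, t, r, b] slices.
    a_l, a_t, a_r, a_b = min(a[0], a[2]), min(a[1], a[3]), max(a[0], a[2]), max(a[1], a[3])
    b_l, b_t, b_r, b_b = min(b[0], b[2]), min(b[1], b[3]), max(b[0], b[2]), max(b[1], b[3])
    w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
    h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
    inter = w * h
    union = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - inter
    return 0.0 if union <= 0.0 else inter / union

def greedy_suppress(rows, iou_threshold, force_suppress=False, id_index=0, coord_start=2):
    # rows: score-sorted boxes in [class_id, score, left, top, right, bottom] layout.
    keep = [True] * len(rows)
    for j in range(len(rows)):
        for k in range(j):
            if not (keep[j] and keep[k]):
                continue
            same_class = force_suppress or id_index < 0 or rows[j][id_index] == rows[k][id_index]
            if same_class and iou(rows[j][coord_start:], rows[k][coord_start:]) >= iou_threshold:
                keep[j] = False  # suppress the later, lower-scoring box
    return keep
```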
)doc" TVM_ADD_FILELINE) - .set_num_inputs(4) + .set_num_inputs(5) .add_argument("data", "Tensor", "Input data.") .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.") .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.") + .add_argument("iou_threshold", "Tensor", "Threshold for box overlap.") .set_support_level(5) .add_type_rel("NMS", NMSRel); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d7a07f7271a9..bae50c9d85f4 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -53,10 +53,9 @@ def get_tvm_output_with_vm( mod, params = relay.frontend.from_onnx( graph_def, shape_dict, opset=opset, freeze_params=freeze_params ) - if convert_to_static: - from tvm.relay import transform - mod = transform.DynamicToStatic()(mod) + if convert_to_static: + mod = relay.transform.DynamicToStatic()(mod) ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) result = ex.evaluate()(*input_data) @@ -2821,7 +2820,6 @@ def forward(self, input): def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"): - print(x_shape, kernel_shape, strides, mode, pads, auto_pad) x_np = np.random.uniform(size=x_shape).astype("float32") if mode == "max": @@ -3690,6 +3688,99 @@ def verify_roi_align( verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0) +# @tvm.testing.uses_gpu +def test_non_max_suppression(): + def verify_nms( + boxes, scores, max_ouput_boxes_per_class, iou_threshold, score_threshold, output_dims + ): + input_names = ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"] + input_nodes = [ + helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes.shape), + helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores.shape), + helper.make_tensor_value_info( + "max_output_boxes_per_class", TensorProto.INT64, max_output_boxes_per_class.shape + ), + helper.make_tensor_value_info("iou_threshold", TensorProto.FLOAT, iou_threshold.shape), + ] + inputs = [boxes, scores, max_output_boxes_per_class, iou_threshold] + if score_threshold is not None: + input_names.append("score_threshold") + input_nodes.append( + helper.make_tensor_value_info( + "score_threshold", TensorProto.FLOAT, score_threshold.shape + ) + ) + inputs.append(score_threshold) + node = helper.make_node( + "NonMaxSuppression", + inputs=input_names, + outputs=["Y"], + center_point_box=0, + ) + + graph = helper.make_graph( + [node], + "nms_test", + inputs=input_nodes, + outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, output_dims)], + ) + + model = helper.make_model(graph, producer_name="nms_test") + + verify_with_ort_with_inputs(model, inputs, use_vm=True) + + boxes = np.array( + [ + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.0, 0.0, 0.5, 0.5], + [0.5, 0.5, 0.9, 0.9], + [0.5, 0.5, 1.0, 1.0], + ], + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.5, 0.5, 0.95, 0.95], + [0.5, 0.5, 0.96, 0.96], + [0.5, 0.5, 1.0, 1.0], + ], + ] + ).astype("float32") + + scores = np.array( + [ + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + ] + ).astype("float32") + max_output_boxes_per_class = np.array(2).astype("int64") + iou_threshold = np.array(0.8).astype("float32") + output_dims = [8, 3] + verify_nms(boxes, scores, max_output_boxes_per_class, 
iou_threshold, None, output_dims) + + boxes = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, -0.1, 1.0, 0.9], + [0.0, 10.0, 1.0, 11.0], + [0.0, 10.1, 1.0, 11.1], + [0.0, 100.0, 1.0, 101.0], + ] + ] + ).astype(np.float32) + scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32) + max_output_boxes_per_class = np.array([3]).astype(np.int64) + iou_threshold = np.array([0.5]).astype(np.float32) + score_threshold = np.array([0.4]).astype(np.float32) + output_dims = [2, 3] + verify_nms( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims + ) + + def verify_cond_loop(): y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1]) y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1]) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index b3b65539cf81..1ce8a182f034 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -314,8 +314,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): intrp = relay.create_executor("debug", ctx=ctx, target=target) out = intrp.evaluate(func)(np_data) tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) - # get_valid_count for cuda, opencl doesn't do data rearrangement - if target in ["cuda", "opencl"]: + # get_valid_count for opencl doesn't do data rearrangement + if target in ["opencl"]: return tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04) diff --git a/tests/python/relay/test_pass_lambda_lift.py b/tests/python/relay/test_pass_lambda_lift.py index b19aebd84ef7..ce737b7bedbb 100644 --- a/tests/python/relay/test_pass_lambda_lift.py +++ b/tests/python/relay/test_pass_lambda_lift.py @@ -34,6 +34,7 @@ def test_basic(): level1_func = relay.Function([x1, y1], level2_func(x1, y1)) mod["main"] = level1_func + mod = relay.transform.InferType()(mod) new_mod = transform.LambdaLift()(mod) assert len(new_mod.functions) == 2 @@ -48,6 +49,7 @@ def test_closure(): clo = outer_func(relay.ones(shape=(2,), dtype="float32")) mod["main"] = relay.Function([], relay.Call(clo, [relay.zeros(shape=(2,), dtype="float32")])) + mod = relay.transform.InferType()(mod) new_mod = transform.LambdaLift()(mod) assert len(new_mod.functions) == 3 @@ -75,6 +77,7 @@ def test_recursive(): ) mod["main"] = relay.Function([x], ret) + mod = relay.transform.InferType()(mod) new_mod = transform.LambdaLift()(mod) assert len(new_mod.functions) == 2
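Finally, the new ONNX NonMaxSuppression coverage runs through the Relay VM path exercised by `get_tvm_output_with_vm`. A hedged, stand-alone sketch of that flow (the model file name and shapes are hypothetical, chosen to match the second test case above; ctx/target are set for CPU here):

```python
import onnx
import tvm
from tvm import relay

# Hypothetical serialized copy of the graph built in verify_nms above.
model = onnx.load("nms_test.onnx")
shape_dict = {
    "boxes": (1, 6, 4),
    "scores": (1, 1, 6),
    "max_output_boxes_per_class": (1,),
    "iou_threshold": (1,),
    "score_threshold": (1,),
}
mod, params = relay.frontend.from_onnx(model, shape_dict, freeze_params=True)
mod = relay.transform.DynamicToStatic()(mod)  # fold dynamic NMS arguments where possible

ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
# selected = ex.evaluate()(boxes, scores, max_output_boxes_per_class,
#                          iou_threshold, score_threshold)
```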