From 7baf939e714773e7d1322a18ee8e5b0f4126db19 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 23 Dec 2020 12:55:14 +0900 Subject: [PATCH 01/10] sort refactor initial import --- python/tvm/topi/cuda/sort.py | 488 +++++++++++++++++------------------ 1 file changed, 233 insertions(+), 255 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 039ebe3aea4e..7f1347d3615b 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -62,46 +62,14 @@ def traverse(op): return s -def sort_ir( - data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None -): - """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. - - Parameters - ---------- - data: Buffer - Buffer of input data. Data will be sorted in place. - - values_out : Buffer - Output buffer of values of sorted tensor with same shape as data. - - values_out_swap : Buffer - Output buffer of values with same shape as data to use as swap. - - axis : Int - Axis long which to sort the input tensor. - - is_ascend : Boolean - Whether to sort in ascending or descending order. +def ceil_div(a, b): + return tvm.tir.indexdiv(a + b - 1, b) - indicess_out : Buffer - Output buffer of indices of sorted tensor with same shape as data. - indices_out_swap : Buffer - Output buffer of indices with same shape as data to use as swap. - - Returns - ------- - stmt : Stmt - The result IR statement. - """ - - def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) +def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None): axis_mul_before = 1 axis_mul_after = 1 - shape = data.shape if axis < 0: axis = len(shape) + axis for i, value in enumerate(shape, 0): @@ -110,16 +78,6 @@ def ceil_div(a, b): elif i > axis: axis_mul_after *= value - ib = tvm.tir.ir_builder.create() - - data = ib.buffer_ptr(data) - values_out = ib.buffer_ptr(values_out) - values_out_swap = ib.buffer_ptr(values_out_swap) - if indices_out is not None: - indices_out = ib.buffer_ptr(indices_out) - assert indices_out_swap is not None - indices_out_swap = ib.buffer_ptr(indices_out_swap) - # Set up threading max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads @@ -127,7 +85,7 @@ def ceil_div(a, b): nthread_by = axis_mul_before nthread_bz = axis_mul_after - # Copy the data to initial output + # Copy the keys_in to initial output with ib.new_scope(): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -141,10 +99,22 @@ def ceil_div(a, b): ib.scope_attr(bz, "thread_extent", nthread_bz) idx = (by * shape[axis] + tid) * axis_mul_after + bz with ib.if_scope(tid < shape[axis]): - values_out[idx] = data[idx] - if indices_out is not None: - indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype) - + keys_out[idx] = keys_in[idx] + if values_out is not None: + values_out[idx] = value_init_func(idx, tid) + + +def _sort_inplace( + ib, + size, + axis_mul_before, + axis_mul_after, + is_ascend, + keys, + keys_swap, + values=None, + values_swap=None, +): ## we are looping over the array doing mergesort from the bottom up. ## The outer loop runs on the host and launches a cuda kernel for each iteration ## of the algorithm. @@ -155,6 +125,83 @@ def ceil_div(a, b): ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc ## On the final iteration of the algorithm, one thread will merge two sorted lists ## to sort the entire array + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(size, max_threads) + nthread_by = axis_mul_before + nthread_bz = axis_mul_after + + def compare(a, b): + """ + Compare a and b in proper ascending or descending order + """ + if is_ascend: + out = a <= b + else: + out = b <= a + return out + + def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even): + """ + Merge the two sections of the array assigned to this thread + """ + # pylint: disable=arguments-out-of-order + # initialize iterators + i[0] = start + j[0] = middle + # set up indexes + base_idx = by * size * axis_mul_after + bz + # iterate over the output loop + with ib.for_range(0, end - start) as k: + i_idx = base_idx + i[0] * axis_mul_after + j_idx = base_idx + j[0] * axis_mul_after + k_idx = base_idx + (k + start) * axis_mul_after + + def swap_values(source, dest, source_idx, dest_idx): + def assign_i(): + """assign i value to current output""" + dest[k_idx] = source[i_idx] + if values is not None: + dest_idx[k_idx] = source_idx[i_idx] + i[0] += 1 + + def assign_j(): + """assign j value to current output""" + dest[k_idx] = source[j_idx] + if values is not None: + dest_idx[k_idx] = source_idx[j_idx] + j[0] += 1 + + ## if both of the iterators are in range + with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)): + # compare them and insert whichever is next into the output + with ib.if_scope(compare(source[i_idx], source[j_idx])): + assign_i() + with ib.else_scope(): + assign_j() + # otherwise, simply copy the remainder of the valid iterator to the output + with ib.else_scope(): + with ib.if_scope(i[0] < middle): + assign_i() + with ib.else_scope(): + assign_j() + + # Switch which input is the source and which is the destination each iteration + with ib.if_scope(even): + swap_values(source, dest, source_idx, dest_idx) + with ib.else_scope(): + swap_values(dest, source, dest_idx, source_idx) + + def mergesort(source, dest, source_idx, dest_idx, size, width, even): + # calculate the start, mid, and end points of this section + start[0] = width * tid + with ib.if_scope(start[0] < size): + middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) + end[0] = tvm.te.min(start[0] + width, size) + ## merge the start->middle and middle->end arrays + bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even) + lim = tvm.tir.generic.cast( tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64" ) @@ -174,7 +221,7 @@ def ceil_div(a, b): ib.scope_attr( bx, "thread_extent", - tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"), + tvm.tir.generic.cast(ceil_div(size, width * max_threads), "int32"), ) tid = bx * nthread_tx + tx @@ -183,85 +230,13 @@ def ceil_div(a, b): ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(bz, "thread_extent", nthread_bz) - def compare(a, b): - """ - Compare a and b in proper ascending or descending order - """ - if is_ascend: - out = a <= b - else: - out = b <= a - return out - - def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even): - """ - Merge the two sections of the array assigned to this thread - """ - # pylint: disable=arguments-out-of-order - # initialize iterators - i[0] = start - j[0] = middle - # set up indexes - base_idx = by * shape[axis] * axis_mul_after + bz - # iterate over the output loop - with ib.for_range(0, end - start) as k: - i_idx = base_idx + i[0] * axis_mul_after - j_idx = base_idx + j[0] * axis_mul_after - k_idx = base_idx + (k + start) * axis_mul_after - - def swap_values(source, dest, source_idx, dest_idx): - def assign_i(): - """assign i value to current output""" - dest[k_idx] = source[i_idx] - if indices_out is not None: - dest_idx[k_idx] = source_idx[i_idx] - i[0] += 1 - - def assign_j(): - """assign j value to current output""" - dest[k_idx] = source[j_idx] - if indices_out is not None: - dest_idx[k_idx] = source_idx[j_idx] - j[0] += 1 - - ## if both of the iterators are in range - with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)): - # compare them and insert whichever is next into the output - with ib.if_scope(compare(source[i_idx], source[j_idx])): - assign_i() - with ib.else_scope(): - assign_j() - # otherwise, simply copy the remainder of the valid iterator to the output - with ib.else_scope(): - with ib.if_scope(i[0] < middle): - assign_i() - with ib.else_scope(): - assign_j() - - # Switch which input is the source and which is the destination each iteration - with ib.if_scope(even): - swap_values(source, dest, source_idx, dest_idx) - with ib.else_scope(): - swap_values(dest, source, dest_idx, source_idx) - - def mergesort(source, dest, source_idx, dest_idx, size, width, even): - # calculate the start, mid, and end points of this section - start[0] = width * tid - with ib.if_scope(start[0] < size): - middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) - end[0] = tvm.te.min(start[0] + width, size) - ## merge the start->middle and middle->end arrays - bottom_up_merge( - source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even - ) - # Call the kernel mergesort( - values_out, - values_out_swap, - indices_out, - indices_out_swap, - shape[axis], + keys, + keys_swap, + values, + values_swap, + size, width, tvm.tir.indexmod(l2_width, 2) == 0, ) @@ -279,29 +254,31 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even): bz = te.thread_axis("blockIdx.z") ib.scope_attr(by, "thread_extent", nthread_by) ib.scope_attr(bz, "thread_extent", nthread_bz) - idx = (by * shape[axis] + tid) * axis_mul_after + bz - with ib.if_scope(tid < shape[axis]): - idx = (by * shape[axis] + tid) * axis_mul_after + bz - values_out[idx] = values_out_swap[idx] - if indices_out is not None: - indices_out[idx] = indices_out_swap[idx] + idx = (by * size + tid) * axis_mul_after + bz + with ib.if_scope(tid < size): + idx = (by * size + tid) * axis_mul_after + bz + keys[idx] = keys_swap[idx] + if values is not None: + values[idx] = values_swap[idx] return ib.get() -def sort_nms_ir(data, valid_count, output, axis, is_ascend): +def sort_ir( + data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None +): """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. Parameters ---------- data: Buffer - Buffer of input data. + Buffer of input data. Data will be sorted in place. - valid_count : Buffer - 1D Buffer of number of valid number of boxes. + values_out : Buffer + Output buffer of values of sorted tensor with same shape as data. - output : Buffer - Output buffer of indicies of sorted tensor with same shape as data. + values_out_swap : Buffer + Output buffer of values with same shape as data to use as swap. axis : Int Axis long which to sort the input tensor. @@ -309,82 +286,109 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): is_ascend : Boolean Whether to sort in ascending or descending order. + indicess_out : Buffer + Output buffer of indices of sorted tensor with same shape as data. + + indices_out_swap : Buffer + Output buffer of indices with same shape as data to use as swap. + Returns ------- stmt : Stmt The result IR statement. """ - - size = 1 - axis_mul_before = 1 - axis_mul_after = 1 - shape = data.shape - if axis < 0: - axis = len(shape) + axis - for i, value in enumerate(shape, 0): - size *= value - if i < axis: - axis_mul_before *= value - elif i > axis: - axis_mul_after *= value - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) ib = tvm.tir.ir_builder.create() + shape = data.shape + data = ib.buffer_ptr(data) - valid_count = ib.buffer_ptr(valid_count) - output = ib.buffer_ptr(output) - nthread_tx = max_threads - nthread_bx = size // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * nthread_tx + tx - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") - is_ascend = tvm.tir.IntImm("int32", is_ascend) - - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - with ib.for_range(0, axis_mul_before) as i: - with ib.for_range(0, axis_mul_after) as j: - current_sort_num = valid_count[i * axis_mul_after + j] - base_idx = i * shape[axis] * axis_mul_after + j - with ib.if_scope(tid < shape[axis]): - output[base_idx + tid * axis_mul_after] = tid - # OddEvenTransposeSort - with ib.for_range(0, current_sort_num) as k: - with ib.if_scope(tid < idxd(current_sort_num + 1, 2)): - offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after - with ib.if_scope( - tvm.tir.all( - is_ascend == 1, - 2 * tid + idxm(k, 2) + 1 < current_sort_num, - data[offset] > data[offset + axis_mul_after], - ) - ): - temp_data[0] = data[offset] - data[offset] = data[offset + axis_mul_after] - data[offset + axis_mul_after] = temp_data[0] - temp_index[0] = output[offset] - output[offset] = output[offset + axis_mul_after] - output[offset + axis_mul_after] = temp_index[0] - with ib.if_scope( - tvm.tir.all( - is_ascend == 0, - 2 * tid + idxm(k, 2) + 1 < current_sort_num, - data[offset] < data[offset + axis_mul_after], - ) - ): - temp_data[0] = data[offset] - data[offset] = data[offset + axis_mul_after] - data[offset + axis_mul_after] = temp_data[0] - temp_index[0] = output[offset] - output[offset] = output[offset + axis_mul_after] - output[offset + axis_mul_after] = temp_index[0] - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + values_out = ib.buffer_ptr(values_out) + values_out_swap = ib.buffer_ptr(values_out_swap) + if indices_out is not None: + indices_out = ib.buffer_ptr(indices_out) + assert indices_out_swap is not None + indices_out_swap = ib.buffer_ptr(indices_out_swap) - return ib.get() + axis_mul_before, axis_mul_after = _sort_init( + ib, + shape, + axis, + data, + values_out, + indices_out, + value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype), + ) + + return _sort_inplace( + ib, + shape[axis], + axis_mul_before, + axis_mul_after, + is_ascend, + values_out, + values_out_swap, + values=indices_out, + values_swap=indices_out_swap, + ) + + +def sort_by_key_ir( + keys_in, values_in, keys_out, values_out, keys_out_swap, values_out_swap, axis, is_ascend +): + """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + + Parameters + ---------- + keys_in: Buffer + Buffer of input keys_in. Keys_in will be sorted in place. + + keys_out : Buffer + Output buffer of values of sorted tensor with same shape as keys_in. + + keys_out_swap : Buffer + Output buffer of values with same shape as keys_in to use as swap. + + axis : Int + Axis long which to sort the input tensor. + + is_ascend : Boolean + Whether to sort in ascending or descending order. + + indicess_out : Buffer + Output buffer of indices of sorted tensor with same shape as keys_in. + + values_out_swap : Buffer + Output buffer of indices with same shape as keys_in to use as swap. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + ib = tvm.tir.ir_builder.create() + shape = keys_in.shape + + keys_in = ib.buffer_ptr(keys_in) + values_in = ib.buffer_ptr(values_in) + keys_out = ib.buffer_ptr(keys_out) + keys_out_swap = ib.buffer_ptr(keys_out_swap) + values_out = ib.buffer_ptr(values_out) + values_out_swap = ib.buffer_ptr(values_out_swap) + + axis_mul_before, axis_mul_after = _sort_init( + ib, axis, keys_in, keys_out, values_out, value_init_func=lambda idx, _: values_in[idx] + ) + + return _sort_inplace( + ib, + shape[axis], + axis_mul_before, + axis_mul_after, + is_ascend, + keys_out, + keys_out_swap, + values=values_out, + values_swap=values_out_swap, + ) def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"): @@ -534,7 +538,7 @@ def sort_thrust(data, axis=-1, is_ascend=1): return out -def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): +def argsort(data, axis=-1, is_ascend=1, dtype="float32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -543,9 +547,6 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): data: tvm.te.Tensor The input array. - valid_count : tvm.te.Tensor, optional - The number of valid elements to be sorted. - axis : int, optional Axis long which to sort the input tensor. @@ -560,49 +561,26 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): out : tvm.te.Tensor The output of this function. """ - if valid_count is not None: - sorted_data = identity(data) - sorted_data_buf = tvm.tir.decl_buffer( - data.shape, data.dtype, "sorted_data_buf", data_alignment=8 - ) - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4 - ) - out_buf = tvm.tir.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) - out = te.extern( - [data.shape], - [sorted_data, valid_count], - lambda ins, outs: sort_nms_ir(ins[0], ins[1], outs[0], axis, is_ascend), - dtype="int32", - in_buffers=[sorted_data_buf, valid_count_buf], - out_buffers=[out_buf], - name="argsort_nms_gpu", - tag="argsort_nms_gpu", - ) - else: - value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) - value_swap_buf = tvm.tir.decl_buffer( - data.shape, data.dtype, "value_swap_buf", data_alignment=8 - ) - indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) - indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) - out = te.extern( - [data.shape, data.shape, data.shape, data.shape], - [data], - lambda ins, outs: sort_ir( - ins[0], - outs[0], - outs[2], - axis, - is_ascend, - indices_out=outs[1], - indices_out_swap=outs[3], - ), - out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], - name="argsort_gpu", - tag="argsort_gpu", - )[1] - return out + value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_swap_buf", data_alignment=8) + indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) + out = te.extern( + [data.shape, data.shape, data.shape, data.shape], + [data], + lambda ins, outs: sort_ir( + ins[0], + outs[0], + outs[2], + axis, + is_ascend, + indices_out=outs[1], + indices_out_swap=outs[3], + ), + out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], + name="argsort_gpu", + tag="argsort_gpu", + )[1] def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): From 8d47aca8f2a848acb0d58a81a515df80c41ab248 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 23 Dec 2020 13:18:20 +0900 Subject: [PATCH 02/10] sort test working --- python/tvm/topi/cuda/nms.py | 2 +- python/tvm/topi/cuda/sort.py | 51 +++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index cea287edd62e..14152258b29e 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -738,7 +738,7 @@ def non_max_suppression( ) else: sort_tensor = argsort( - score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype + score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype ) sort_tensor_buf = tvm.tir.decl_buffer( diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 7f1347d3615b..94e8a778730c 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -103,6 +103,8 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f if values_out is not None: values_out[idx] = value_init_func(idx, tid) + return axis_mul_before, axis_mul_after + def _sort_inplace( ib, @@ -203,7 +205,7 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even): bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even) lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" ) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width @@ -581,6 +583,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): name="argsort_gpu", tag="argsort_gpu", )[1] + return out def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): @@ -840,6 +843,52 @@ def schedule_topk(outs): return _schedule_sort(outs) +def sort_by_key(keys, values, axis=-1, is_ascend=1): + """Sort values with respect to keys. Both keys and values will + be sorted and returned. + + Parameters + ---------- + keys: tvm.te.Tensor + The 1D input keys. + + values : tvm.te.Tensor, + The 1D input values. + + Returns + ------- + keys_sorted : tvm.te.Tensor + The sorted keys + + values_sorted : tvm.te.Tensor + The values sorted with respect to the keys + """ + keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8) + values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8) + keys_swap_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_swap_buf", data_alignment=8) + values_swap_buf = tvm.tir.decl_buffer( + values.shape, values.dtype, "values_swap_buf", data_alignment=8 + ) + + out_bufs = [ + tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8), + tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", data_alignment=8), + ] + out = te.extern( + [keys.shape, values.shape, keys.shape, values.shape], + [keys, values], + lambda ins, outs: sort_by_key_ir( + ins[0], ins[1], outs[0], outs[1], outs[2], outs[3], axis, is_ascend + ), + in_buffers=[keys_buf, values_buf], + out_buffers=out_bufs, + dtype=[keys.dtype, values.dtype], + name="sort_by_key", + tag="sort_by_key", + ) + return out[0], out[1] + + def stable_sort_by_key_thrust(keys, values, for_scatter=False): """Sort values with respect to keys using thrust. Both keys and values will be sorted and returned. From 5b507d63a9998b8a920a15a68fe485c2f79b21d3 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 23 Dec 2020 13:39:23 +0900 Subject: [PATCH 03/10] scatter 1d with positive indices working --- python/tvm/topi/cuda/scatter.py | 22 ++++++++-------------- python/tvm/topi/cuda/sort.py | 12 +++++++----- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 9916e2a7fa6d..879e27b45008 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -20,7 +20,7 @@ from tvm import te from ..scatter import _verify_scatter_nd_inputs from .nms import atomic_add -from .sort import stable_sort_by_key_thrust, is_thrust_available +from .sort import stable_sort_by_key_thrust, is_thrust_available, sort_by_key def ceil_div(a, b): @@ -417,7 +417,7 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): return ib.get() -def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): +def gen_scatter_1d_sorted(data, indices_sorted, updates_sorted, axis, out, _): """Generate scatter ir for 1d inputs, using a sorting based approach. By sorting indices and comparing neighboring two indices, we can tell which of elements in the indices tensor can scatter its update value into the output. @@ -473,12 +473,6 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): ni = indices_sorted.shape[0] - def do_update(ib, index, update): - with ib.if_scope(index < 0): - out_ptr[index + n] = update - with ib.else_scope(): - out_ptr[index] = update - with ib.new_scope(): nthread_bx = ceil_div(ni, nthread_tx) tx = te.thread_axis("threadIdx.x") @@ -491,7 +485,7 @@ def do_update(ib, index, update): # The last element can always update. index = indices_ptr[tid] update = updates_ptr[tid] - do_update(ib, index, update) + out_ptr[index] = update with ib.else_scope(): with ib.if_scope(tid < ni - 1): @@ -503,7 +497,7 @@ def do_update(ib, index, update): # This thread can update the output. with ib.if_scope(index != index_next): update = updates_ptr[tid] - do_update(ib, index, update) + out_ptr[index] = update return ib.get() @@ -539,7 +533,7 @@ def scatter(data, indices, updates, axis=0): assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions" ir_funcs = { - 1: gen_ir_1d, + 1: gen_scatter_1d_sorted, 2: gen_ir_2d, 3: gen_ir_3d, 4: gen_ir_4d, @@ -553,14 +547,14 @@ def update_func(dst_ptr, dst_index, update): in_bufs = [data] - if rank == 1 and is_thrust_available(): - ir_funcs[1] = gen_scatter_1d_thrust + if False and rank == 1 and is_thrust_available(): indices_sorted, updates_sorted = stable_sort_by_key_thrust( indices, updates, for_scatter=True ) in_bufs += [indices_sorted, updates_sorted] else: - in_bufs += [indices, updates] + indices_sorted, updates_sorted = sort_by_key(indices, updates) + in_bufs += [indices_sorted, updates_sorted] out = te.extern( [out_shape], diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 94e8a778730c..115cefd1911b 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -106,7 +106,7 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f return axis_mul_before, axis_mul_after -def _sort_inplace( +def _sort_common( ib, size, axis_mul_before, @@ -320,7 +320,7 @@ def sort_ir( value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype), ) - return _sort_inplace( + return _sort_common( ib, shape[axis], axis_mul_before, @@ -377,10 +377,10 @@ def sort_by_key_ir( values_out_swap = ib.buffer_ptr(values_out_swap) axis_mul_before, axis_mul_after = _sort_init( - ib, axis, keys_in, keys_out, values_out, value_init_func=lambda idx, _: values_in[idx] + ib, shape, axis, keys_in, keys_out, values_out, value_init_func=lambda idx, _: values_in[idx] ) - return _sort_inplace( + return _sort_common( ib, shape[axis], axis_mul_before, @@ -872,7 +872,9 @@ def sort_by_key(keys, values, axis=-1, is_ascend=1): out_bufs = [ tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8), - tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", data_alignment=8), + tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8), + tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_swap_buf", data_alignment=8), + tvm.tir.decl_buffer(values.shape, values.dtype, "values_swap_buf", data_alignment=8), ] out = te.extern( [keys.shape, values.shape, keys.shape, values.shape], From 3e87a23bc01152ec30aeb1dc6c0fca3df80dca52 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 23 Dec 2020 14:54:08 +0900 Subject: [PATCH 04/10] remove negatiev indices, using extern for now --- python/tvm/topi/cuda/scatter.py | 51 ++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 879e27b45008..7356b0116b5e 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -18,6 +18,8 @@ """Scatter operator """ import tvm from tvm import te +from .. import tag +from ..transform import where, shape from ..scatter import _verify_scatter_nd_inputs from .nms import atomic_add from .sort import stable_sort_by_key_thrust, is_thrust_available, sort_by_key @@ -502,6 +504,52 @@ def gen_scatter_1d_sorted(data, indices_sorted, updates_sorted, axis, out, _): return ib.get() +def _remove_negative_indices(indices): + """Convert negative indices to corresponding positive indices""" + + def _ir(indices, out): + size = indices.shape[0] + ib = tvm.tir.ir_builder.create() + + indices = ib.buffer_ptr(indices) + out = ib.buffer_ptr(out) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + + tid = bx * max_threads + tx + with ib.if_scope(tid < size): + with ib.if_scope(indices[tid] < 0): + out[tid] = indices[tid] + size + with ib.else_scope(): + out[tid] = indices[tid] + + return ib.get() + + indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) + out_indices_buf = tvm.tir.decl_buffer( + indices.shape, indices.dtype, "out_indices_buf", data_alignment=8 + ) + return te.extern( + [indices.shape], + [indices], + lambda ins, outs: _ir( + ins[0], + outs[0], + ), + dtype=[indices.dtype], + in_buffers=[indices_buf], + out_buffers=[out_indices_buf], + name="remove_negative_indices", + tag="remove_negative_indices", + ) + + def scatter(data, indices, updates, axis=0): """Update data at positions defined by indices with values in updates @@ -547,12 +595,13 @@ def update_func(dst_ptr, dst_index, update): in_bufs = [data] - if False and rank == 1 and is_thrust_available(): + if rank == 1 and is_thrust_available(): indices_sorted, updates_sorted = stable_sort_by_key_thrust( indices, updates, for_scatter=True ) in_bufs += [indices_sorted, updates_sorted] else: + indices = _remove_negative_indices(indices) indices_sorted, updates_sorted = sort_by_key(indices, updates) in_bufs += [indices_sorted, updates_sorted] From afb86fafe04395c72d71b129bec1344bc6ea6187 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 15:01:09 +0900 Subject: [PATCH 05/10] minor fix --- python/tvm/topi/cuda/scatter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 7356b0116b5e..e11226312efb 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -18,8 +18,6 @@ """Scatter operator """ import tvm from tvm import te -from .. import tag -from ..transform import where, shape from ..scatter import _verify_scatter_nd_inputs from .nms import atomic_add from .sort import stable_sort_by_key_thrust, is_thrust_available, sort_by_key @@ -424,7 +422,7 @@ def gen_scatter_1d_sorted(data, indices_sorted, updates_sorted, axis, out, _): By sorting indices and comparing neighboring two indices, we can tell which of elements in the indices tensor can scatter its update value into the output. Sorting of indices, and sorting of updates with respect to indices, can be done - at the same time by thrust's sort_by_key function. It is important that sorting + at the same time by sort_by_key function. It is important that sorting be done in a "stable" way via stable_sort, to guarantee deterministic output. Parameters @@ -505,7 +503,7 @@ def gen_scatter_1d_sorted(data, indices_sorted, updates_sorted, axis, out, _): def _remove_negative_indices(indices): - """Convert negative indices to corresponding positive indices""" + """Convert negative indices to corresponding positive indices.""" def _ir(indices, out): size = indices.shape[0] From 1334260bab0d3981fea0e7d0a104cca4cf4c68b4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 16:02:50 +0900 Subject: [PATCH 06/10] minor fix --- python/tvm/topi/cuda/sort.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 115cefd1911b..209cbfbc23fb 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -865,10 +865,6 @@ def sort_by_key(keys, values, axis=-1, is_ascend=1): """ keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8) values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8) - keys_swap_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_swap_buf", data_alignment=8) - values_swap_buf = tvm.tir.decl_buffer( - values.shape, values.dtype, "values_swap_buf", data_alignment=8 - ) out_bufs = [ tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8), From 7049ede48051a8060ef0884c13f61fd011daf24c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 16:11:11 +0900 Subject: [PATCH 07/10] add sort by key test --- tests/python/contrib/test_sort.py | 35 ++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index 9d6eb7cb3a1e..f338276ca118 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm import te -from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available +from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available, sort_by_key import numpy as np @@ -123,7 +123,40 @@ def test_thrust_stable_sort_by_key(): tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) +def test_sort_by_key_gpu(): + size = 6 + keys = te.placeholder((size,), name="keys", dtype="int32") + values = te.placeholder((size,), name="values", dtype="int32") + + for target in ["cuda", "nvptx", "opencl", "rocm"]: + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + continue + + with tvm.target.Target(target): + keys_out, values_out = sort_by_key(keys, values) + ctx = tvm.context(target) + s = te.create_schedule([keys_out.op, values_out.op]) + f = tvm.build(s, [keys, values, keys_out, values_out], target) + + keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) + values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) + keys_np_out = np.zeros(keys_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + keys_in = tvm.nd.array(keys_np, ctx) + values_in = tvm.nd.array(values_np, ctx) + keys_out = tvm.nd.array(keys_np_out, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(keys_in, values_in, keys_out, values_out) + + ref_keys_out = np.sort(keys_np) + ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) + tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + + if __name__ == "__main__": test_sort() test_sort_np() test_thrust_stable_sort_by_key() + test_sort_by_key_gpu() From d9ed4d2dad6df6263ded6adabd7dab506a6eb1ac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 16:17:45 +0900 Subject: [PATCH 08/10] revert scatter change --- python/tvm/topi/cuda/scatter.py | 61 +++++---------------------------- 1 file changed, 8 insertions(+), 53 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index e11226312efb..be602c8ab7a3 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -20,7 +20,7 @@ from tvm import te from ..scatter import _verify_scatter_nd_inputs from .nms import atomic_add -from .sort import stable_sort_by_key_thrust, is_thrust_available, sort_by_key +from .sort import stable_sort_by_key_thrust, is_thrust_available def ceil_div(a, b): @@ -417,13 +417,15 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): return ib.get() -def gen_scatter_1d_sorted(data, indices_sorted, updates_sorted, axis, out, _): +def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _): """Generate scatter ir for 1d inputs, using a sorting based approach. By sorting indices and comparing neighboring two indices, we can tell which of elements in the indices tensor can scatter its update value into the output. Sorting of indices, and sorting of updates with respect to indices, can be done - at the same time by sort_by_key function. It is important that sorting + at the same time by thrust's sort_by_key function. It is important that sorting be done in a "stable" way via stable_sort, to guarantee deterministic output. + Negative indices are assumed to have been converted to corresponding positive + indices. Parameters ---------- @@ -502,52 +504,6 @@ def gen_scatter_1d_sorted(data, indices_sorted, updates_sorted, axis, out, _): return ib.get() -def _remove_negative_indices(indices): - """Convert negative indices to corresponding positive indices.""" - - def _ir(indices, out): - size = indices.shape[0] - ib = tvm.tir.ir_builder.create() - - indices = ib.buffer_ptr(indices) - out = ib.buffer_ptr(out) - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = ceil_div(size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - - tid = bx * max_threads + tx - with ib.if_scope(tid < size): - with ib.if_scope(indices[tid] < 0): - out[tid] = indices[tid] + size - with ib.else_scope(): - out[tid] = indices[tid] - - return ib.get() - - indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8) - out_indices_buf = tvm.tir.decl_buffer( - indices.shape, indices.dtype, "out_indices_buf", data_alignment=8 - ) - return te.extern( - [indices.shape], - [indices], - lambda ins, outs: _ir( - ins[0], - outs[0], - ), - dtype=[indices.dtype], - in_buffers=[indices_buf], - out_buffers=[out_indices_buf], - name="remove_negative_indices", - tag="remove_negative_indices", - ) - - def scatter(data, indices, updates, axis=0): """Update data at positions defined by indices with values in updates @@ -579,7 +535,7 @@ def scatter(data, indices, updates, axis=0): assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions" ir_funcs = { - 1: gen_scatter_1d_sorted, + 1: gen_ir_1d, 2: gen_ir_2d, 3: gen_ir_3d, 4: gen_ir_4d, @@ -594,14 +550,13 @@ def update_func(dst_ptr, dst_index, update): in_bufs = [data] if rank == 1 and is_thrust_available(): + ir_funcs[1] = gen_scatter_1d_thrust indices_sorted, updates_sorted = stable_sort_by_key_thrust( indices, updates, for_scatter=True ) in_bufs += [indices_sorted, updates_sorted] else: - indices = _remove_negative_indices(indices) - indices_sorted, updates_sorted = sort_by_key(indices, updates) - in_bufs += [indices_sorted, updates_sorted] + in_bufs += [indices, updates] out = te.extern( [out_shape], From 13d706fdb0d20a82059f97af2a3b4d5e39780170 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 16:30:04 +0900 Subject: [PATCH 09/10] add document --- python/tvm/topi/cuda/sort.py | 40 ++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 209cbfbc23fb..18872a242160 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -21,7 +21,6 @@ from tvm._ffi import get_global_func from .injective import schedule_injective_from_existing -from ..math import identity from ..transform import strided_slice, transpose from .. import tag @@ -67,7 +66,7 @@ def ceil_div(a, b): def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None): - + """Initialize the output buffers by copying from inputs""" axis_mul_before = 1 axis_mul_after = 1 if axis < 0: @@ -117,6 +116,8 @@ def _sort_common( values=None, values_swap=None, ): + """Either sort only values or sort values by keys.""" + ## we are looping over the array doing mergesort from the bottom up. ## The outer loop runs on the host and launches a cuda kernel for each iteration ## of the algorithm. @@ -269,7 +270,7 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even): def sort_ir( data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None ): - """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. Parameters ---------- @@ -336,19 +337,28 @@ def sort_ir( def sort_by_key_ir( keys_in, values_in, keys_out, values_out, keys_out_swap, values_out_swap, axis, is_ascend ): - """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + """Low level IR to do sort by key on the GPU. Parameters ---------- keys_in: Buffer - Buffer of input keys_in. Keys_in will be sorted in place. + Buffer of input keys. + + values_in: Buffer + Buffer of input keys. keys_out : Buffer - Output buffer of values of sorted tensor with same shape as keys_in. + Buffer of output sorted keys. + + values_out : Buffer + Buffer of output sorted values. keys_out_swap : Buffer Output buffer of values with same shape as keys_in to use as swap. + values_out_swap : Buffer + Output buffer of values with same shape as values_in to use as swap. + axis : Int Axis long which to sort the input tensor. @@ -377,7 +387,13 @@ def sort_by_key_ir( values_out_swap = ib.buffer_ptr(values_out_swap) axis_mul_before, axis_mul_after = _sort_init( - ib, shape, axis, keys_in, keys_out, values_out, value_init_func=lambda idx, _: values_in[idx] + ib, + shape, + axis, + keys_in, + keys_out, + values_out, + value_init_func=lambda idx, _: values_in[idx], ) return _sort_common( @@ -850,10 +866,16 @@ def sort_by_key(keys, values, axis=-1, is_ascend=1): Parameters ---------- keys: tvm.te.Tensor - The 1D input keys. + The input keys. values : tvm.te.Tensor, - The 1D input values. + The input values. + + axis : int, optional + Axis long which to sort the input tensor. + + is_ascend : boolean, optional + Whether to sort in ascending or descending order. Returns ------- From f79abc87af775909f3c56ab3039f08ed9ea80028 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 16:33:45 +0900 Subject: [PATCH 10/10] fix py format --- python/tvm/topi/cuda/nms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 14152258b29e..020cf9b5bc63 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -737,9 +737,7 @@ def non_max_suppression( score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype ) else: - sort_tensor = argsort( - score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype - ) + sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype) sort_tensor_buf = tvm.tir.decl_buffer( sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8