From f6488d5a836e2a0453b753b5cc486321edadb7ad Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 11 Dec 2020 22:38:43 -0700 Subject: [PATCH 01/11] implement parallel cuda mergesort fix lint --- python/tvm/topi/cuda/sort.py | 241 ++++++++++++++++++++------- tests/python/relay/test_any.py | 4 +- tests/python/relay/test_op_level6.py | 5 +- 3 files changed, 183 insertions(+), 67 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index ea149054fa65..43aabded1cdb 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument, no-else-return """Sort related operators """ import tvm from tvm import te @@ -62,7 +62,9 @@ def traverse(op): return s -def sort_ir(data, values_out, axis, is_ascend, indices_out=None): +def sort_ir( + data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None +): """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. 
Parameters @@ -94,64 +96,155 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): axis_mul_before *= value elif i > axis: axis_mul_after *= value - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + ib = tvm.tir.ir_builder.create() + data = ib.buffer_ptr(data) values_out = ib.buffer_ptr(values_out) + values_out_swap = ib.buffer_ptr(values_out_swap) if indices_out is not None: indices_out = ib.buffer_ptr(indices_out) + assert indices_out_swap is not None + indices_out_swap = ib.buffer_ptr(indices_out_swap) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = shape[axis] // max_threads + 1 - - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * nthread_tx + tx - temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local") - if indices_out is not None: - temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local") - - with ib.for_range(0, axis_mul_before) as i: - with ib.for_range(0, axis_mul_after) as j: - base_idx = i * shape[axis] * axis_mul_after + j + nthread_by = axis_mul_before + nthread_bz = axis_mul_after + + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + idx = (by * shape[axis] + tid) * axis_mul_after + bz + with ib.if_scope(tid < shape[axis]): + idx = (by * shape[axis] + tid) * axis_mul_after + bz + values_out[idx] = data[idx] + if indices_out is not None: + indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype) + + source = 
values_out + dest = values_out_swap + source_idx = indices_out + dest_idx = indices_out_swap + lim = tvm.tir.generic.cast( + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int32" + ) + with ib.for_range(0, lim) as l2_width: + width = 2 << l2_width + slices = tvm.tir.indexdiv(shape[axis], (max_threads * width)) + 1 + + with ib.new_scope(): + i = ib.allocate("int32", (1,), name="i", scope="local") + j = ib.allocate("int32", (1,), name="j", scope="local") + start = ib.allocate("int32", (1,), name="start", scope="local") + middle = ib.allocate("int32", (1,), name="middle", scope="local") + end = ib.allocate("int32", (1,), name="end", scope="local") + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + + def compare(a, b): + if is_ascend: + out = a <= b + else: + out = b <= a + return out + + def BottomUpMerge(source, dest, source_idx, dest_idx, start, middle, end, even): + # pylint: disable=arguments-out-of-order + i[0] = start + j[0] = middle + base_idx = by * shape[axis] * axis_mul_after + bz + with ib.for_range(0, end - start) as k: + i_idx = base_idx + i[0] * axis_mul_after + j_idx = base_idx + j[0] * axis_mul_after + k_idx = base_idx + (k + start) * axis_mul_after + + def swap_values(source, dest, source_idx, dest_idx): + def assign_i(): + dest[k_idx] = source[i_idx] + if indices_out is not None: + dest_idx[k_idx] = source_idx[i_idx] + i[0] += 1 + + def assign_j(): + dest[k_idx] = source[j_idx] + if indices_out is not None: + dest_idx[k_idx] = source_idx[j_idx] + j[0] += 1 + + with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)): + with ib.if_scope(compare(source[i_idx], source[j_idx])): + assign_i() + with 
ib.else_scope(): + assign_j() + with ib.else_scope(): + with ib.if_scope(i[0] < middle): + assign_i() + with ib.else_scope(): + assign_j() + + with ib.if_scope(even): + swap_values(source, dest, source_idx, dest_idx) + with ib.else_scope(): + swap_values(dest, source, dest_idx, source_idx) + + def MergeSort(source, dest, source_idx, dest_idx, size, width, slices, even): + start[0] = width * tid * slices + with ib.for_range(0, slices): + with ib.if_scope(start[0] < size): + middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) + end[0] = tvm.te.min(start[0] + width, size) + BottomUpMerge( + source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even + ) + start[0] += width + + MergeSort( + source, + dest, + source_idx, + dest_idx, + shape[axis], + width, + slices, + tvm.tir.indexmod(l2_width, 2) == 0, + ) + + with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1): + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + bz = te.thread_axis("blockIdx.z") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(bz, "thread_extent", nthread_bz) + idx = (by * shape[axis] + tid) * axis_mul_after + bz with ib.if_scope(tid < shape[axis]): - values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after] + idx = (by * shape[axis] + tid) * axis_mul_after + bz + values_out[idx] = values_out_swap[idx] if indices_out is not None: - indices_out[base_idx + tid * axis_mul_after] = tvm.tir.generic.cast( - tid, indices_out.dtype - ) - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - with ib.for_range(0, axis_mul_before) as i: - with ib.for_range(0, axis_mul_after) as j: - current_sort_num = shape[axis] - base_idx = i * shape[axis] * 
axis_mul_after + j - # OddEvenTransposeSort - with ib.for_range(0, current_sort_num) as k: - with ib.if_scope(tid < idxd(current_sort_num + 1, 2)): - offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after - if is_ascend: - cond = tvm.tir.all( - 2 * tid + idxm(k, 2) + 1 < current_sort_num, - values_out[offset] > values_out[offset + axis_mul_after], - ) - else: - cond = tvm.tir.all( - 2 * tid + idxm(k, 2) + 1 < current_sort_num, - values_out[offset] < values_out[offset + axis_mul_after], - ) - with ib.if_scope(cond): - temp_data[0] = values_out[offset] - values_out[offset] = values_out[offset + axis_mul_after] - values_out[offset + axis_mul_after] = temp_data[0] - if indices_out is not None: - temp_index[0] = indices_out[offset] - indices_out[offset] = indices_out[offset + axis_mul_after] - indices_out[offset + axis_mul_after] = temp_index[0] - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + indices_out[idx] = indices_out_swap[idx] return ib.get() @@ -449,12 +542,24 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): ) else: value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + value_swap_buf = tvm.tir.decl_buffer( + data.shape, data.dtype, "value_swap_buf", data_alignment=8 + ) indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) out = te.extern( - [data.shape, data.shape], + [data.shape, data.shape, data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), - out_buffers=[value_buf, indices_buf], + lambda ins, outs: sort_ir( + ins[0], + outs[0], + outs[2], + axis, + is_ascend, + indices_out=outs[1], + indices_out_swap=outs[3], + ), + out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], name="argsort_gpu", tag="argsort_gpu", )[1] @@ -564,25 +669,37 @@ def topk(data, 
k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): axis = axis + ndim if axis < 0 else axis assert 0 <= axis < ndim values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8) + values_swap_buf = tvm.tir.decl_buffer( + data.shape, data.dtype, "values_swap_buf", data_alignment=8 + ) indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8) + indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8) if ret_type == "values": output = te.extern( - [data.shape], + [data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend), - out_buffers=[values_buf], + lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend), + out_buffers=[values_buf, values_swap_buf], name="topk_gpu", tag="topk_gpu", - ) + )[0] else: output = te.extern( - [data.shape, data.shape], + [data.shape, data.shape, data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), - out_buffers=[values_buf, indices_buf], + lambda ins, outs: sort_ir( + ins[0], + outs[0], + outs[2], + axis, + is_ascend, + indices_out=outs[1], + indices_out_swap=outs[3], + ), + out_buffers=[values_buf, indices_buf, values_swap_buf, indices_swap_buf], name="topk_gpu", tag="topk_gpu", - ) + )[0:2] if isinstance(k, int) and k < 1: if ret_type == "indices": return output[1] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index dfc03c0cf6b1..43292d640e19 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -250,9 +250,7 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): check_result([data], mod, expected, flatten=True) -# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have -# to use thrust to guarantee the correct results which has been tested locally. 
-# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_any_argwhere(): verify_any_argwhere(any_dims(1), (5,)) verify_any_argwhere(any_dims(2), (5, 5)) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index a5ce1fdcf589..c56783dd9c07 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -66,9 +66,9 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): func = relay.Function([x], z) x_data = np.random.uniform(size=shape).astype("float32") if is_ascend: - ref_res = np.argsort(x_data, axis=axis) + ref_res = np.argsort(x_data, axis=axis, kind="stable") else: - ref_res = np.argsort(-x_data, axis=axis) + ref_res = np.argsort(-x_data, axis=axis, kind="stable") if is_dyn: backends = ["vm", "debug"] @@ -86,6 +86,7 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn) verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype) @tvm.testing.uses_gpu From 2a50289f75e9c4d51c2152744121144017c21d84 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Mon, 14 Dec 2020 16:22:11 -0700 Subject: [PATCH 02/11] fix a bug in build module when optimizing the host section of mixed host/device code --- python/tvm/driver/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 058bd62d6226..dc9d741b2726 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -277,7 +277,7 @@ def _build_for_device(input_mod, target, target_host): lambda f: "calling_conv" not in f.attrs or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH ), - tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)), + 
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)), tvm.tir.transform.LowerTVMBuiltin(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerCustomDatatypes(), From 18179892006b5569dfce5e9dd3dafd4ed9885c4f Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Mon, 14 Dec 2020 22:29:36 -0700 Subject: [PATCH 03/11] convert loop indices to int64 to prevent overflow in start calculation --- python/tvm/topi/cuda/sort.py | 14 +++++++------- tests/python/relay/test_op_level6.py | 5 ++++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 43aabded1cdb..f71f491d36a0 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -136,18 +136,18 @@ def sort_ir( source_idx = indices_out dest_idx = indices_out_swap lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int32" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64" ) - with ib.for_range(0, lim) as l2_width: + with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width slices = tvm.tir.indexdiv(shape[axis], (max_threads * width)) + 1 with ib.new_scope(): - i = ib.allocate("int32", (1,), name="i", scope="local") - j = ib.allocate("int32", (1,), name="j", scope="local") - start = ib.allocate("int32", (1,), name="start", scope="local") - middle = ib.allocate("int32", (1,), name="middle", scope="local") - end = ib.allocate("int32", (1,), name="end", scope="local") + i = ib.allocate("int64", (1,), name="i", scope="local") + j = ib.allocate("int64", (1,), name="j", scope="local") + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) diff --git 
a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index c56783dd9c07..fa1f6ebfbc76 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -53,6 +53,8 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False): verify_sort((2, 3, 4), axis=0, is_ascend=False, is_dyn=is_dyn) verify_sort((1, 4, 6), axis=1, is_ascend=True, is_dyn=is_dyn) verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn) + verify_sort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_sort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) @tvm.testing.uses_gpu @@ -86,7 +88,8 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False): verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn) verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn) verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) - verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype) + verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) @tvm.testing.uses_gpu From dcf18e11fab021e9fbe7a1ee9bed2e43fecd9bca Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Mon, 14 Dec 2020 23:18:00 -0700 Subject: [PATCH 04/11] comments and cleanup --- python/tvm/topi/cuda/sort.py | 92 ++++++++++++++++++++++++++---------- 1 file changed, 66 insertions(+), 26 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index f71f491d36a0..657e8816ea25 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -26,6 +26,10 @@ from .. import tag +def ceil_div(a, b): + return (a + b - 1) // b + + def swap(arr, axis): """ swap arr[axis] and arr[-1] """ return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]] @@ -72,8 +76,11 @@ def sort_ir( data: Buffer Buffer of input data. 
Data will be sorted in place. - output : Buffer - Output buffer of indicies of sorted tensor with same shape as data. + values_out : Buffer + Output buffer of values of sorted tensor with same shape as data. + + values_out_swap : Buffer + Output buffer of values with same shape as data to use as swap. axis : Int Axis long which to sort the input tensor. @@ -81,11 +88,21 @@ def sort_ir( is_ascend : Boolean Whether to sort in ascending or descending order. + indices_out : Buffer + Output buffer of indices of sorted tensor with same shape as data. + + indices_out_swap : Buffer + Output buffer of indices with same shape as data to use as swap. + Returns ------- stmt : Stmt The result IR statement. """ + + def ceil_div(a, b): + return tvm.tir.indexdiv(a + b - 1, b) + axis_mul_before = 1 axis_mul_after = 1 shape = data.shape @@ -107,12 +124,14 @@ def sort_ir( assert indices_out_swap is not None indices_out_swap = ib.buffer_ptr(indices_out_swap) + # Set up threading max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads - nthread_bx = shape[axis] // max_threads + 1 + nthread_bx = ceil_div(shape[axis], max_threads) nthread_by = axis_mul_before nthread_bz = axis_mul_after + # Copy the data to initial output with ib.new_scope(): tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -131,17 +150,17 @@ def sort_ir( if indices_out is not None: indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype) - source = values_out - dest = values_out_swap - source_idx = indices_out - dest_idx = indices_out_swap + ## we are looping over the array doing mergesort from the bottom up. + ## The outer loop runs on the host and launches a cuda kernel for each iteration + ## of the algorithm. + ## The basic idea is that at iteration 0, each thread does sort on 2 elements. On iteration 1, each thread merges 2 sorted arrays of 2 elements, to deal with 4 total elements. 
On iteration 2, each thread merges 2 sorted arrays of 4 elements, to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc + ## On the final iteration of the algorithm, one thread will merge two sorted lists to sort the entire array lim = tvm.tir.generic.cast( tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64" ) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width - slices = tvm.tir.indexdiv(shape[axis], (max_threads * width)) + 1 - + # Define and launch the cuda kernel with ib.new_scope(): i = ib.allocate("int64", (1,), name="i", scope="local") j = ib.allocate("int64", (1,), name="j", scope="local") @@ -151,7 +170,12 @@ def sort_ir( tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) + # Reduce the number of blocks as the work per thread grows + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(shape[axis], width * max_threads), "int32"), + ) tid = bx * nthread_tx + tx by = te.thread_axis("blockIdx.y") @@ -160,6 +184,9 @@ def sort_ir( ib.scope_attr(bz, "thread_extent", nthread_bz) def compare(a, b): + """ + Compare a and b in proper ascending or descending order + """ if is_ascend: out = a <= b else: @@ -167,10 +194,16 @@ def compare(a, b): return out def BottomUpMerge(source, dest, source_idx, dest_idx, start, middle, end, even): + """ + Merge the two sections of the array assigned to this thread + """ # pylint: disable=arguments-out-of-order + # initialize iterators i[0] = start j[0] = middle + # set up indexes base_idx = by * shape[axis] * axis_mul_after + bz + # iterate over the output loop with ib.for_range(0, end - start) as k: i_idx = base_idx + i[0] * axis_mul_after j_idx = base_idx + j[0] * axis_mul_after @@ -178,55 +211,62 @@ def BottomUpMerge(source, dest, source_idx, dest_idx, start, middle, end, even): def swap_values(source, dest, 
source_idx, dest_idx): def assign_i(): + """assign i value to current output""" dest[k_idx] = source[i_idx] if indices_out is not None: dest_idx[k_idx] = source_idx[i_idx] i[0] += 1 def assign_j(): + """assign j value to current output""" dest[k_idx] = source[j_idx] if indices_out is not None: dest_idx[k_idx] = source_idx[j_idx] j[0] += 1 + ## if both of the iterators are in range with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)): + # compare them and insert whichever is next into the output with ib.if_scope(compare(source[i_idx], source[j_idx])): assign_i() with ib.else_scope(): assign_j() + # otherwise, simply copy the remainder of the valid iterator to the output with ib.else_scope(): with ib.if_scope(i[0] < middle): assign_i() with ib.else_scope(): assign_j() + # Switch which input is the source and which is the destination each iteration with ib.if_scope(even): swap_values(source, dest, source_idx, dest_idx) with ib.else_scope(): swap_values(dest, source, dest_idx, source_idx) - def MergeSort(source, dest, source_idx, dest_idx, size, width, slices, even): - start[0] = width * tid * slices - with ib.for_range(0, slices): - with ib.if_scope(start[0] < size): - middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) - end[0] = tvm.te.min(start[0] + width, size) - BottomUpMerge( - source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even - ) - start[0] += width - + def MergeSort(source, dest, source_idx, dest_idx, size, width, even): + # calculate the start, mid, and end points of this section + start[0] = width * tid + with ib.if_scope(start[0] < size): + middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) + end[0] = tvm.te.min(start[0] + width, size) + ## merge the start->middle and middle->end arrays + BottomUpMerge( + source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even + ) + + # Call the kernel MergeSort( - source, - dest, - source_idx, - dest_idx, + values_out, + values_out_swap, + indices_out, 
+ indices_out_swap, shape[axis], width, - slices, tvm.tir.indexmod(l2_width, 2) == 0, ) + ## if the final sorted data ended up in the swap, copy it to the real output with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1): with ib.new_scope(): tx = te.thread_axis("threadIdx.x") From 84a2e2d5b721ee6e04a31013e54e00683395d3d2 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Mon, 14 Dec 2020 23:37:16 -0700 Subject: [PATCH 05/11] fix lint --- python/tvm/topi/cuda/sort.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 657e8816ea25..8dc507e3576c 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -26,10 +26,6 @@ from .. import tag -def ceil_div(a, b): - return (a + b - 1) // b - - def swap(arr, axis): """ swap arr[axis] and arr[-1] """ return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]] @@ -145,7 +141,6 @@ def ceil_div(a, b): ib.scope_attr(bz, "thread_extent", nthread_bz) idx = (by * shape[axis] + tid) * axis_mul_after + bz with ib.if_scope(tid < shape[axis]): - idx = (by * shape[axis] + tid) * axis_mul_after + bz values_out[idx] = data[idx] if indices_out is not None: indices_out[idx] = tvm.tir.generic.cast(tid, indices_out.dtype) @@ -153,8 +148,13 @@ def ceil_div(a, b): ## we are looping over the array doing mergesort from the bottom up. ## The outer loop runs on the host and launches a cuda kernel for each iteration ## of the algorithm. - ## The basic idea is that at iteration 0, each thread does sort on 2 elements. On iteration 1, each thread merges 2 sorted arrays of 2 elements, to deal with 4 total elements. On iteration 2, each thread merges 2 sorted arrays of 4 elements, to deal with 8 total elements. 
On iteration 3, each thread deals with 16 elements, etc - ## On the final iteration of the algorithm, one thread will merge two sorted lists to sort the entire array + ## The basic idea is that at iteration 0, each thread does sort on 2 elements. + ## On iteration 1, each thread merges 2 sorted arrays of 2 elements, + ## to deal with 4 total elements. + ## On iteration 2, each thread merges 2 sorted arrays of 4 elements, + ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc + ## On the final iteration of the algorithm, one thread will merge two sorted lists + ## to sort the entire array lim = tvm.tir.generic.cast( tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(shape[axis], "float64"))), "int64" ) From 2d6d51662527f4f2adf80e646c4bff19a0c0a4e0 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Wed, 16 Dec 2020 12:14:53 -0700 Subject: [PATCH 06/11] fix python casing --- python/tvm/topi/cuda/sort.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 8dc507e3576c..e4d00eb332dd 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -193,7 +193,7 @@ def compare(a, b): out = b <= a return out - def BottomUpMerge(source, dest, source_idx, dest_idx, start, middle, end, even): + def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even): """ Merge the two sections of the array assigned to this thread """ @@ -244,19 +244,19 @@ def assign_j(): with ib.else_scope(): swap_values(dest, source, dest_idx, source_idx) - def MergeSort(source, dest, source_idx, dest_idx, size, width, even): + def mergesort(source, dest, source_idx, dest_idx, size, width, even): # calculate the start, mid, and end points of this section start[0] = width * tid with ib.if_scope(start[0] < size): middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size) end[0] = tvm.te.min(start[0] + width, size) ## merge the start->middle and 
middle->end arrays - BottomUpMerge( + bottom_up_merge( source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even ) # Call the kernel - MergeSort( + mergesort( values_out, values_out_swap, indices_out, From 5bc7a41fdcf5ae6af7d8ab31d1db79ff7e3d95d5 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 18 Dec 2020 10:15:06 -0700 Subject: [PATCH 07/11] enable more flaky tests --- tests/python/relay/test_any.py | 4 +--- tests/python/topi/python/test_topi_argwhere.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 43292d640e19..cfa9825cd7e7 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -837,8 +837,7 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): check_result(in_vals, mod, ref_out) -# TODO(kevinthesun): enable this test when Thrust is available in ci. -# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") verify_any_topk(any_dims(2), 2, (6, 3), "int32") @@ -1387,7 +1386,6 @@ def test_any_where(): any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4) ) - # TODO(kevinthesun): enable gpu test when Thrust is available in ci. # @tvm.testing.uses_gpu def test_non_max_suppression(): diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py index 433030863a43..cef761c93fb3 100644 --- a/tests/python/topi/python/test_topi_argwhere.py +++ b/tests/python/topi/python/test_topi_argwhere.py @@ -66,9 +66,7 @@ def check_device(device, ctx): check_device(target, ctx) -# TODO(zhiics) Enable argwhere gpu test after sort is fixed. Otherwise, we have -# to use thrust to guarantee the correct results which has been tested locally. 
-# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_argwhere(): verify_argwhere((1,)) verify_argwhere((100,)) From 1cbbb8f4a2662cc1926fa2b6a263e8b5d31dac88 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 18 Dec 2020 13:50:28 -0700 Subject: [PATCH 08/11] fix lint --- tests/python/relay/test_any.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index cfa9825cd7e7..e6812aa3bbfa 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1386,6 +1386,7 @@ def test_any_where(): any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4) ) + # TODO(kevinthesun): enable gpu test when Thrust is available in ci. # @tvm.testing.uses_gpu def test_non_max_suppression(): From e0b89946af25aee3c8258f6636b552b3b915cc99 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 18 Dec 2020 15:36:27 -0700 Subject: [PATCH 09/11] fix lint, really enable test --- python/tvm/topi/cuda/sort.py | 1 + tests/python/topi/python/test_topi_argwhere.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index e4d00eb332dd..ce3baaa35170 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -450,6 +450,7 @@ def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32") def sort(data, axis=-1, is_ascend=1): + # pylint: disable=no-value-for-parameter """Performs sorting along the given axis and returns an array of sorted values with the same shape as the input data. 
diff --git a/tests/python/topi/python/test_topi_argwhere.py b/tests/python/topi/python/test_topi_argwhere.py index cef761c93fb3..69993d287b79 100644 --- a/tests/python/topi/python/test_topi_argwhere.py +++ b/tests/python/topi/python/test_topi_argwhere.py @@ -60,9 +60,6 @@ def check_device(device, ctx): tvm.testing.assert_allclose(args[-1].asnumpy(), np.array(np_out)) for target, ctx in tvm.testing.enabled_targets(): - # TODO(zhiics) Enable argwhere gpu test after sort is fixed. - if ctx.device_type != 1: - continue check_device(target, ctx) From 7a5318067e836f81095d2c411b49ece6739ff4e9 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 18 Dec 2020 19:17:19 -0700 Subject: [PATCH 10/11] fix bad rebase --- python/tvm/topi/cuda/sort.py | 7 +++---- tests/python/relay/test_op_level6.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index ce3baaa35170..512be3626118 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -450,7 +450,6 @@ def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32") def sort(data, axis=-1, is_ascend=1): - # pylint: disable=no-value-for-parameter """Performs sorting along the given axis and returns an array of sorted values with the same shape as the input data. 
@@ -472,12 +471,12 @@ def sort(data, axis=-1, is_ascend=1): """ dtype = "float32" value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) - indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8) out = te.extern( [data.shape, data.shape], [data], - lambda ins, outs: sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1]), - out_buffers=[value_buf, indices_buf], + lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend), + out_buffers=[value_buf, value_buf_swap], name="sort_gpu", tag="sort_gpu", )[0] diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index fa1f6ebfbc76..0dac69e36025 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -53,8 +53,8 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False): verify_sort((2, 3, 4), axis=0, is_ascend=False, is_dyn=is_dyn) verify_sort((1, 4, 6), axis=1, is_ascend=True, is_dyn=is_dyn) verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn) - verify_sort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) - verify_sort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn) + verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn) + verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn) @tvm.testing.uses_gpu From 7abb3a0c50c65379f0ed385862564bf3ddfbea67 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 18 Dec 2020 20:42:27 -0700 Subject: [PATCH 11/11] fix lint --- python/tvm/topi/cuda/sort.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 512be3626118..039ebe3aea4e 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -469,7 +469,6 @@ def sort(data, axis=-1, is_ascend=1): out : tvm.te.Tensor The output of this function. 
""" - dtype = "float32" value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8) out = te.extern(