diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index ac660bfb7461..6445bb1fe73f 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1312,6 +1312,10 @@ constexpr const char* fragment_shape = "fragment_shape";
  */
 constexpr const char* fragment_layout = "fragment_layout";
+/*!
+ * \brief Mark that the kernel is hand-threaded and doesn't need syncs inserted
+ */
+constexpr const char* hand_threaded = "hand_threaded";
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index ca832ef0ef36..5ebd3060a6bb 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -57,6 +57,20 @@ def traverse(op):
     return s
 
 
+def _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz):
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+    by = te.thread_axis("blockIdx.y")
+    bz = te.thread_axis("blockIdx.z")
+    ib.scope_attr(by, "thread_extent", nthread_by)
+    ib.scope_attr(bz, "thread_extent", nthread_bz)
+
+    return tx, bx, by, bz
+
+
 def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None):
     """Initialize the output buffers by copying from inputs"""
     axis_mul_before = 1
@@ -78,16 +92,8 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
 
     # Copy the keys_in to initial output
     with ib.new_scope():
-        tx = te.thread_axis("threadIdx.x")
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(tx, "thread_extent", nthread_tx)
-        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
         tid = bx * nthread_tx + tx
-
-        by = te.thread_axis("blockIdx.y")
-        bz = te.thread_axis("blockIdx.z")
-        ib.scope_attr(by, "thread_extent", nthread_by)
-        ib.scope_attr(bz, "thread_extent", nthread_bz)
         idx = (by * shape[axis] + tid) * axis_mul_after + bz
         with ib.if_scope(tid < shape[axis]):
             keys_out[idx] = keys_in[idx]
@@ -97,6 +103,100 @@ def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_f
     return axis_mul_before, axis_mul_after
 
 
+## TODO(mbrookhart): These are effectively optimization hyperparameters.
+## Perhaps we can autotune?
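+## block_size sets how many elements each block sorts in shared memory with the
+## odd-even kernel below (which launches block_size // 2 threads, two elements
+## per thread); thread_work sets how many output elements each thread produces
+## per step in the dual-level mergepath kernels.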
+block_size = 128
+thread_work = 4
+
+
+def _odd_even_sort(
+    ib,
+    size,
+    axis_mul_before,
+    axis_mul_after,
+    is_ascend,
+    keys,
+    keys_swap,
+    values=None,
+    values_swap=None,
+):
+
+    nthread_tx = block_size // 2
+    nthread_bx = ceil_div(size, block_size)
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+    with ib.new_scope():
+        ib.scope_attr(tvm.tir.const(0), "hand_threaded", 0)
+        tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
+        tid = 2 * tx
+        start = bx * block_size
+
+        ## Create shared memory as syncable thread scratch space
+        tmp_keys_swap = ib.allocate(
+            keys_swap.dtype,
+            (block_size,),
+            name="temp_keys_swap",
+            scope="shared",
+        )
+        if values_swap is not None:
+            tmp_values_swap = ib.allocate(
+                values_swap.dtype,
+                (block_size,),
+                name="temp_values_swap",
+                scope="shared",
+            )
+
+        ## Create thread local data for swapping
+        temp_keys = ib.allocate(keys_swap.dtype, (1,), name="temp_keys", scope="local")
+        if values_swap is not None:
+            temp_values = ib.allocate(values_swap.dtype, (1,), name="temp_values", scope="local")
+
+        temp_cond1 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond1", scope="local")
+        temp_cond2 = ib.allocate(keys_swap.dtype, (1,), name="temp_cond2", scope="local")
+        # Copy data to scratch space
+        base_idx = by * size * axis_mul_after + bz
+        with ib.for_range(0, 2) as n:
+            with ib.if_scope((tid + n + start) < size):
+                tmp_keys_swap[tid + n] = keys[base_idx + (tid + n + start) * axis_mul_after]
+                if values_swap is not None:
+                    tmp_values_swap[tid + n] = values[base_idx + (tid + n + start) * axis_mul_after]
+
+        ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+
+        idxm = tvm.tir.indexmod
+        # OddEvenTransposeSort
+        current_sort_num = tvm.tir.min(block_size, size - start)
+        with ib.for_range(0, current_sort_num) as k:
+            n = idxm(tid + k, 2)
+            with ib.if_scope(tid + n < current_sort_num - 1):
+                temp_cond1[0] = tmp_keys_swap[tid + n]
+                temp_cond2[0] = tmp_keys_swap[tid + n + 1]
+                if is_ascend:
+                    cond = temp_cond1[0] > temp_cond2[0]
+                else:
+                    cond = temp_cond1[0] < temp_cond2[0]
+                with ib.if_scope(cond):
+                    temp_keys[0] = tmp_keys_swap[tid + n]
+                    tmp_keys_swap[tid + n] = tmp_keys_swap[tid + n + 1]
+                    tmp_keys_swap[tid + n + 1] = temp_keys[0]
+                    if values_swap is not None:
+                        temp_values[0] = tmp_values_swap[tid + n]
+                        tmp_values_swap[tid + n] = tmp_values_swap[tid + n + 1]
+                        tmp_values_swap[tid + n + 1] = temp_values[0]
+            ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+
+        ## Copy sorted data to output
+        with ib.for_range(0, 2) as n:
+            with ib.if_scope(tid + n + start < size):
+                keys[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
+                keys_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_keys_swap[tid + n]
+                if values_swap is not None:
+                    values[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[tid + n]
+                    values_swap[base_idx + (tid + n + start) * axis_mul_after] = tmp_values_swap[
+                        tid + n
+                    ]
+
+
 def _sort_common(
     ib,
     size,
@@ -110,22 +210,22 @@
 ):
     """Either sort only values or sort values by keys."""
 
-    ## we are looping over the array doing mergesort from the bottom up.
-    ## The outer loop runs on the host and launches a cuda kernel for each iteration
-    ## of the algorithm.
-    ## The basic idea is that at iteration 0, each thread does sort on 2 elements.
-    ## On iteration 1, each thread merges 2 sorted arrays of 2 elements,
-    ## to deal with 4 total elements.
-    ## On iteration 2, each thread merges 2 sorted arrays of 4 elements,
-    ## to deal with 8 total elements. On iteration 3, each thread deals with 16 elements, etc
-    ## On the final iteration of the algorithm, one thread will merge two sorted lists
-    ## to sort the entire array
+    ## This function performs a multi-level mergesort.
+    ## For blocks of length <= block_size, it does odd-even transpose sort
+    ## in GPU shared memory.
+    ## For intermediate block sizes (> block_size, < max_threads * thread_work),
+    ## it uses the mergepath algorithm (https://arxiv.org/abs/1406.2628)
+    ## to merge blocks in parallel.
+    ## At some point, the size of the blocks to be merged is too big for max_threads,
+    ## and we switch to a dual-level mergepath where the outer mergepath
+    ## finds the start/end locations of the inner mergepath so that we can split
+    ## the merge into more blocks.
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_by = axis_mul_before * axis_mul_after
+    nthread_bz = 1
     nthread_tx = max_threads
-    nthread_bx = ceil_div(size, max_threads)
-    nthread_by = axis_mul_before
-    nthread_bz = axis_mul_after
+    nthread_bx = ceil_div(size, nthread_tx)
 
     def compare(a, b):
         """
@@ -137,91 +237,234 @@ def compare(a, b):
         out = b <= a
         return out
 
-    def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even):
-        """
-        Merge the two sections of the array assigned to this thread
-        """
-        # pylint: disable=arguments-out-of-order
-        # initialize iterators
+    # Sort the lower levels of the merge using odd-even sort; it's fast for small inputs
+    lower_lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64"
+    )
+
+    _odd_even_sort(
+        ib,
+        size,
+        axis_mul_before * axis_mul_after,
+        1,
+        is_ascend,
+        keys,
+        keys_swap,
+        values,
+        values_swap,
+    )
+
+    upper_lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
+    )
+
+    def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
+        first = ib.allocate("int64", (1,), name="first", scope="local")
+        mid = ib.allocate("int64", (1,), name="mid", scope="local")
+        last = ib.allocate("int64", (1,), name="last", scope="local")
+        first[0] = tvm.te.max(0, diag - bCount)
+        last[0] = tvm.te.min(diag, aCount)
+        with ib.while_loop(first[0] < last[0]):
+            mid = (first[0] + last[0]) >> 1
+            a = source[base_idx + (aStart + mid)]
+            b = source[base_idx + (bStart + diag - 1 - mid)]
+            with ib.if_scope(compare(a, b)):
+                first[0] = mid + 1
+            with ib.else_scope():
+                last[0] = mid
+        return first[0], last[0]
+
+    def serial_merge(
+        source,
+        dest,
+        source_idx,
+        dest_idx,
+        base_idx,
+        aCount,
+        bCount,
+        aStart,
+        bStart,
+        kStart,
+        diag,
+        step_count,
+        first,
+        last,
+    ):
         i = ib.allocate("int64", (1,), name="i", scope="local")
         j = ib.allocate("int64", (1,), name="j", scope="local")
-        i[0] = start
-        j[0] = middle
-        # set up indexes
-        base_idx = by * size * axis_mul_after + bz
-        # iterate over the output loop
-        with ib.for_range(0, end - start) as k:
-            i_idx = base_idx + i[0] * axis_mul_after
-            j_idx = base_idx + j[0] * axis_mul_after
-            k_idx = base_idx + (k + start) * axis_mul_after
-
-            def swap_values(source, dest, source_idx, dest_idx):
-                def assign_i():
-                    """assign i value to current output"""
-                    dest[k_idx] = source[i_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[i_idx]
-                    i[0] += 1
-
-                def assign_j():
-                    """assign j value to current output"""
-                    dest[k_idx] = source[j_idx]
-                    if values is not None:
-                        dest_idx[k_idx] = source_idx[j_idx]
-                    j[0] += 1
-
-                ## if both of the iterators are in range
-                with ib.if_scope(tvm.tir.all(i[0] < middle, j[0] < end)):
-                    # compare them and insert whichever is next into the output
-                    with ib.if_scope(compare(source[i_idx], source[j_idx])):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
-                # otherwise, simply copy the remainder of the valid iterator to the output
-                with ib.else_scope():
-                    with ib.if_scope(i[0] < middle):
-                        assign_i()
-                    with ib.else_scope():
-                        assign_j()
+        i[0] = aStart + first
+        j[0] = bStart + diag - last
+        with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
+            i_idx = base_idx + i[0]
+            j_idx = base_idx + j[0]
+            k_idx = base_idx + (kStart + diag + count)
+
+            def assign_i():
+                """assign i value to current output"""
+                dest[k_idx] = source[i_idx]
+                if values is not None:
+                    dest_idx[k_idx] = source_idx[i_idx]
+                i[0] += 1
 
-        # Switch which input is the source and which is the destination each iteration
-        with ib.if_scope(even):
-            swap_values(source, dest, source_idx, dest_idx)
+            def assign_j():
+                """assign j value to current output"""
+                dest[k_idx] = source[j_idx]
+                if values is not None:
+                    dest_idx[k_idx] = source_idx[j_idx]
+                j[0] += 1
+
+            ## if both of the iterators are in range
+            with ib.if_scope(tvm.tir.all(i[0] < aStart + aCount, j[0] < bStart + bCount)):
+                # compare them and insert whichever is next into the output
+                with ib.if_scope(compare(source[i_idx], source[j_idx])):
+                    assign_i()
+                with ib.else_scope():
+                    assign_j()
+            # otherwise, simply copy the remainder of the valid iterator to the output
             with ib.else_scope():
-                swap_values(dest, source, dest_idx, source_idx)
-
-    def mergesort(source, dest, source_idx, dest_idx, size, width, even):
-        # calculate the start, mid, and end points of this section
-        start = width * tid
-
-        with ib.if_scope(start < size):
-            middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
-            end = cast(tvm.te.min(start + width, size), "int64")
-            # merge the start->middle and middle->end arrays
-            bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even)
+                with ib.if_scope(i[0] < aStart + aCount):
+                    assign_i()
+                with ib.else_scope():
+                    assign_j()
 
-    lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
-    )
-    with ib.for_range(0, lim, dtype="int64") as l2_width:
-        width = 2 << l2_width
+    with ib.for_range(0, upper_lim - lower_lim, dtype="int64") as l2_width:
+        width = 2 << (l2_width + lower_lim)
         # Define and launch the cuda kernel
         with ib.new_scope():
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            # Reduce the number of blocks as the work per thread grows
-            ib.scope_attr(
-                bx,
-                "thread_extent",
-                tvm.tir.generic.cast(ceil_div(size, width * max_threads), "int32"),
-            )
-            tid = bx * nthread_tx + tx
-
-            by = te.thread_axis("blockIdx.y")
-            bz = te.thread_axis("blockIdx.z")
-            ib.scope_attr(by, "thread_extent", nthread_by)
-            ib.scope_attr(bz, "thread_extent", nthread_bz)
+            target = tvm.target.Target.current()
+            if "vulkan" in str(target):
+                # Vulkan can't handle dynamic nthread, so we thread slightly differently
+                # for vulkan. We don't do this generally because it causes a 15% perf
+                # regression on other platforms
+                ntx = max_threads
+                nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
+                nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
+                tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+            else:
+                ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32")
+                nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32")
+                nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32")
+                tx, bx, by, bz = _get_threads(ib, ntx, nbx, nthread_by, nbz)
+
+            def mergepath(
+                source,
+                dest,
+                source_idx,
+                dest_idx,
+                aCount,
+                bCount,
+                aStart,
+                bStart,
+                kStart,
+                step_count,
+                even,
+            ):
+                # pylint: disable=arguments-out-of-order
+                def merge(source, dest, source_idx, dest_idx):
+                    diag = tx * step_count
+                    first, last = get_merge_begin(
+                        source,
+                        by * size,
+                        aCount,
+                        bCount,
+                        aStart,
+                        bStart,
+                        diag,
+                        step_count,
+                    )
+                    # iterate over the output loop
+                    serial_merge(
+                        source,
+                        dest,
+                        source_idx,
+                        dest_idx,
+                        by * size,
+                        aCount,
+                        bCount,
+                        aStart,
+                        bStart,
+                        kStart,
+                        diag,
+                        step_count,
+                        first,
+                        last,
+                    )
+
+                with ib.if_scope(even):
+                    merge(source, dest, source_idx, dest_idx)
+                with ib.else_scope():
+                    merge(dest, source, dest_idx, source_idx)
+
+            def mergesort(source, dest, source_idx, dest_idx, size, width, even):
+                # calculate the start, mid, and end points of this section
+                start = width * bz
+                middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
+                end = cast(tvm.te.min(start + width, size), "int64")
+                with ib.if_scope(start < size):
+                    with ib.if_scope(nbx == 1):
+                        ## merge the start->middle and middle->end arrays
+                        aCount = middle - start
+                        bCount = end - middle
+                        mergepath(
+                            source,
+                            dest,
+                            source_idx,
+                            dest_idx,
+                            aCount,
+                            bCount,
+                            start,
+                            middle,
+                            start,
+                            ceil_div(width, ntx),
+                            even,
+                        )
+                    with ib.else_scope():
+                        step_count = max_threads * thread_work
+                        diag = bx * step_count
+
+                        def do_merge(first, last):
+                            aStart = start + first
+                            bStart = middle + diag - last
+                            aCount = tvm.te.min(middle - aStart, step_count)
+                            bCount = tvm.te.min(end - bStart, step_count)
+                            mergepath(
+                                source,
+                                dest,
+                                source_idx,
+                                dest_idx,
+                                aCount,
+                                bCount,
+                                aStart,
+                                bStart,
+                                start + diag,
+                                thread_work,
+                                even,
+                            )
+
+                        with ib.if_scope(even):
+                            first, last = get_merge_begin(
+                                source,
+                                by * size,
+                                middle - start,
+                                end - middle,
+                                start,
+                                middle,
+                                diag,
+                                step_count,
+                            )
+                            do_merge(first, last)
+                        with ib.else_scope():
+                            first, last = get_merge_begin(
+                                dest,
+                                by * size,
+                                middle - start,
+                                end - middle,
+                                start,
+                                middle,
+                                diag,
+                                step_count,
+                            )
+                            do_merge(first, last)
 
             # Call the kernel
             mergesort(
@@ -233,29 +476,23 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even):
                 width,
                 tvm.tir.indexmod(l2_width, 2) == 0,
             )
-
+    nthread_by = axis_mul_before
+    nthread_bz = axis_mul_after
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(size, nthread_tx)
     ## if the final sorted data ended up in the swap, copy it to the real output
-    with ib.if_scope(tvm.tir.indexmod(lim, 2) == 1):
+    with ib.if_scope(
+        tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)
+    ):
         with ib.new_scope():
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            tx, bx, by, bz = _get_threads(ib, nthread_tx, nthread_bx, nthread_by, nthread_bz)
            tid = bx * nthread_tx + tx
-
-            by = te.thread_axis("blockIdx.y")
-            bz = te.thread_axis("blockIdx.z")
-            ib.scope_attr(by, "thread_extent", nthread_by)
-            ib.scope_attr(bz, "thread_extent", nthread_bz)
-            idx = (by * size + tid) * axis_mul_after + bz
+            idx = (by * axis_mul_after + bz) * size + tid
             with ib.if_scope(tid < size):
-                idx = (by * size + tid) * axis_mul_after + bz
                 keys[idx] = keys_swap[idx]
                 if values is not None:
                     values[idx] = values_swap[idx]
-    return ib.get()
-
 
 def sort_ir(
     data, values_out, values_out_swap, axis, is_ascend, indices_out=None, indices_out_swap=None
@@ -301,27 +538,30 @@ def sort_ir(
         assert indices_out_swap is not None
         indices_out_swap = ib.buffer_ptr(indices_out_swap)
 
-    axis_mul_before, axis_mul_after = _sort_init(
-        ib,
-        shape,
-        axis,
-        data,
-        values_out,
-        indices_out,
-        value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
-    )
+    with ib.if_scope(shape[axis] > 0):
+        axis_mul_before, axis_mul_after = _sort_init(
+            ib,
+            shape,
+            axis,
+            data,
+            values_out,
+            indices_out,
+            value_init_func=lambda _, tid: tvm.tir.generic.cast(tid, indices_out.dtype),
+        )
+
+        _sort_common(
+            ib,
+            shape[axis],
+            axis_mul_before,
+            axis_mul_after,
+            is_ascend,
+            values_out,
+            values_out_swap,
+            values=indices_out,
+            values_swap=indices_out_swap,
+        )
 
-    return _sort_common(
-        ib,
-        shape[axis],
-        axis_mul_before,
-        axis_mul_after,
-        is_ascend,
-        values_out,
-        values_out_swap,
-        values=indices_out,
-        values_swap=indices_out_swap,
-    )
+    return ib.get()
 
 
 def sort_by_key_ir(
@@ -376,27 +616,29 @@ def sort_by_key_ir(
     values_out = ib.buffer_ptr(values_out)
     values_out_swap = ib.buffer_ptr(values_out_swap)
 
-    axis_mul_before, axis_mul_after = _sort_init(
-        ib,
-        shape,
-        axis,
-        keys_in,
-        keys_out,
-        values_out,
-        value_init_func=lambda idx, _: values_in[idx],
-    )
-
-    return _sort_common(
-        ib,
-        shape[axis],
-        axis_mul_before,
-        axis_mul_after,
-        is_ascend,
-        keys_out,
-        keys_out_swap,
-        values=values_out,
-        values_swap=values_out_swap,
-    )
+    with ib.if_scope(shape[axis] > 0):
+        axis_mul_before, axis_mul_after = _sort_init(
+            ib,
+            shape,
+            axis,
+            keys_in,
+            keys_out,
+            values_out,
+            value_init_func=lambda idx, _: values_in[idx],
+        )
+
+        _sort_common(
+            ib,
+            shape[axis],
+            axis_mul_before,
+            axis_mul_after,
+            is_ascend,
+            keys_out,
+            keys_out_swap,
+            values=values_out,
+            values_swap=values_out_swap,
+        )
+    return ib.get()
 
 
 def sort(data, axis=-1, is_ascend=1):
@@ -419,16 +661,29 @@ def sort(data, axis=-1, is_ascend=1):
     out : tvm.te.Tensor
         The output of this function.
     """
+    ndim = len(data.shape)
+    axis = ndim + axis if axis < 0 else axis
+    if axis != ndim - 1:
+        # Prepare for sorting along axis -1.
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
     value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
     value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8)
+
     out = te.extern(
         [data.shape, data.shape],
         [data],
-        lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+        lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
         out_buffers=[value_buf, value_buf_swap],
         name="sort_gpu",
        tag="sort_gpu",
     )[0]
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        out = transpose(out, axes)
+
     return out
@@ -507,10 +762,18 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
     out : tvm.te.Tensor
         The output of this function.
     """
+    ndim = len(data.shape)
+    axis = ndim + axis if axis < 0 else axis
+    if axis != ndim - 1:
+        # Prepare for sorting along axis -1.
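+        # swap() builds the permutation that exchanges `axis` with the last
+        # axis, so transpose() moves the sort axis innermost; applying the same
+        # permutation after the kernel restores the original layout.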
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
     value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
     value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_swap_buf", data_alignment=8)
     indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
     indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8)
+
     out = te.extern(
         [data.shape, data.shape, data.shape, data.shape],
         [data],
@@ -518,7 +781,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
             ins[0],
             outs[0],
             outs[2],
-            axis,
+            -1,
             is_ascend,
             indices_out=outs[1],
             indices_out_swap=outs[3],
@@ -527,6 +790,11 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
         name="argsort_gpu",
         tag="argsort_gpu",
     )[1]
+
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        out = transpose(out, axes)
+
     return out
@@ -625,21 +893,30 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
     ndim = len(data.shape)
     axis = axis + ndim if axis < 0 else axis
     assert 0 <= axis < ndim
+    dshape = data.shape
+    if axis != ndim - 1:
+        axes = swap(list(range(ndim)), axis)
+        data = transpose(data, axes)
+
     values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
     values_swap_buf = tvm.tir.decl_buffer(
         data.shape, data.dtype, "values_swap_buf", data_alignment=8
     )
     indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
     indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8)
+
     if ret_type == "values":
         output = te.extern(
             [data.shape, data.shape],
             [data],
-            lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], axis, is_ascend),
+            lambda ins, outs: sort_ir(ins[0], outs[0], outs[1], -1, is_ascend),
             out_buffers=[values_buf, values_swap_buf],
             name="topk_gpu",
             tag="topk_gpu",
         )[0]
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            output = transpose(output, axes)
     else:
         output = te.extern(
             [data.shape, data.shape, data.shape, data.shape],
             [data],
@@ -648,7 +925,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
                 ins[0],
                 outs[0],
                 outs[2],
-                axis,
+                -1,
                 is_ascend,
                 indices_out=outs[1],
                 indices_out_swap=outs[3],
@@ -657,6 +934,11 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
             name="topk_gpu",
             tag="topk_gpu",
         )[0:2]
+        if axis != ndim - 1:
+            axes = swap(list(range(ndim)), axis)
+            output[0] = transpose(output[0], axes)
+            output[1] = transpose(output[1], axes)
+
     if isinstance(k, int) and k < 1:
         if ret_type == "indices":
             return output[1]
@@ -668,7 +950,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
         if i == axis:
             end.append(k if isinstance(k, int) else tvm.te.size_var("dim"))
         else:
-            end.append(data.shape[i])
+            end.append(dshape[i])
     if ret_type == "both":
         values_out, indices_out = output
         values_out = strided_slice(values_out, beg, end, strides)
diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 38143c14b021..00002d3587db 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -132,6 +132,10 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
       StmtExprVisitor::VisitStmt_(op);
     }
     env_threads_.pop_back();
+  } else if (op->attr_key == attr::hand_threaded) {
+    // skip this pass on blocks that were hand_threaded
+    // this avoids control flow and read/write conflicts
+    // between hand-threaded kernels and automatic threading
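+    // (the attribute is attached with ib.scope_attr(..., "hand_threaded", 0)
+    // in _odd_even_sort in python/tvm/topi/cuda/sort.py)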
   } else {
     StmtExprVisitor::VisitStmt_(op);
   }
diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py
index 0dac69e36025..f4b785f59df8 100644
--- a/tests/python/relay/test_op_level6.py
+++ b/tests/python/relay/test_op_level6.py
@@ -26,6 +26,7 @@
 @tvm.testing.uses_gpu
 def test_sort():
     def verify_sort(shape, axis, is_ascend, is_dyn=False):
+
         if is_dyn:
             x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
         else:
@@ -87,9 +88,11 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
         for dtype in ["int32", "int64", "float32", "float64"]:
             verify_argsort((2, 3, 4), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
             verify_argsort((1, 4, 6), axis=1, is_ascend=True, dtype=dtype, is_dyn=is_dyn)
-            verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
-            verify_argsort((3, 2000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
-            verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        dtype = "int32"
+        verify_argsort((3, 5, 6), axis=-1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
+        verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
 
 
 @tvm.testing.uses_gpu
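For reviewers unfamiliar with these algorithms, here is a minimal host-side Python sketch (not part of the patch; all names are illustrative) of the two building blocks the new kernel composes: odd-even transposition sort for blocks that fit in shared memory, and the mergepath diagonal search that get_merge_begin/serial_merge use to split a merge across independent threads.

```python
import random


def odd_even_sort(block):
    """In-place odd-even transposition sort: after len(block) alternating
    phases the block is fully sorted (the shared-memory kernel's strategy)."""
    n = len(block)
    for phase in range(n):
        for i in range(phase % 2, n - 1, 2):
            if block[i] > block[i + 1]:
                block[i], block[i + 1] = block[i + 1], block[i]


def merge_path_search(a, b, diag):
    """Binary search on the merge path: how many elements of `a` land before
    output position `diag` when merging sorted runs `a` and `b` (this mirrors
    get_merge_begin above)."""
    first = max(0, diag - len(b))
    last = min(diag, len(a))
    while first < last:
        mid = (first + last) >> 1
        if a[mid] <= b[diag - 1 - mid]:
            first = mid + 1
        else:
            last = mid
    return first


def parallel_merge(a, b, step):
    """Merge sorted runs `a` and `b` in independent chunks of `step` outputs;
    each chunk stands in for one thread's serial_merge call."""
    out = []
    for diag in range(0, len(a) + len(b), step):
        i = merge_path_search(a, b, diag)
        j = diag - i
        for _ in range(min(step, len(a) + len(b) - diag)):
            # Take from `a` on ties, consistent with the diagonal search
            if j >= len(b) or (i < len(a) and a[i] <= b[j]):
                out.append(a[i])
                i += 1
            else:
                out.append(b[j])
                j += 1
    return out


rng = random.Random(0)
block = [rng.randrange(100) for _ in range(16)]
odd_even_sort(block)
assert block == sorted(block)

a = sorted(rng.randrange(100) for _ in range(37))
b = sorted(rng.randrange(100) for _ in range(23))
assert parallel_merge(a, b, step=8) == sorted(a + b)
```

Each parallel_merge chunk is independent, which is what lets the kernel assign one chunk per thread; the host loop in _sort_common then doubles width each iteration, ping-ponging between the primary and swap buffers via the even flag.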