From 930043e8b7a7aa4edd4afa35886f6d653272e88b Mon Sep 17 00:00:00 2001 From: wenxizhu Date: Thu, 15 Jul 2021 17:30:44 +0800 Subject: [PATCH 01/19] [TOPI][CUDA] Improve the performance of scatter_nd by: 1. Split into 2 kernels, one does the "Init" and another does the "Update". Thus they can have different Grid/Block configurations to better utilize SMs. 2. Use atomic_add instead of direct assignment, which could avoid the race condtion when multiple indices point to the same location of the output tensor. With this moidification, it's safe now to use more CUDA threads to gain more parallelism. --- python/tvm/topi/cuda/scatter.py | 56 ++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index c697b648786e..4808a3293523 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -787,42 +787,42 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): for i in data_ptr.shape: fused_shape *= i - # For now we avoid parallizing over dimensions indexed by `indices` as - # there may be repeated indices and hadling parallel accumulation can - # be hard. So we parallelize over X_M .. X_{N-1} instead. This will - # work well when these dimensions are large enough to saturate memory - # bandwidth, but performance will be bad when these dimensions are - # small. - bx = te.thread_axis("blockIdx.x") - tx = te.thread_axis("threadIdx.x") - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - tdim = min(max_threads, fused_updates_dimension) - ib.scope_attr(tx, "thread_extent", tdim) - bdim = ceil_div(fused_updates_dimension, tdim) - ib.scope_attr(bx, "thread_extent", bdim) - - # Copy data into the output. This loop writes to the same portions of - # memory as the following loop, so we do not need a memory sync. - with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i: - index = i * fused_updates_dimension + bx * tdim + tx - with ib.if_scope(bx * tdim + tx < fused_updates_dimension): + # Init output tensor. + with ib.new_scope(): + bidx = te.thread_axis("blockIdx.x") + tidx = te.thread_axis("threadIdx.x") + gridDim = 1 + for i in data_ptr.shape[:-1]: + gridDim *= i + blockDim = data_ptr.shape[-1] + + ib.scope_attr(bidx, "thread_extent", gridDim) + ib.scope_attr(tidx, "thread_extent", blockDim) + index = bidx * blockDim + tidx + with ib.if_scope(index < fused_shape): out[index] = data[index] - with ib.for_range(0, fused_indices_dimension) as i: - j = bx * tdim + tx + # Update output tensor by given values. + with ib.new_scope(): + bidx = te.thread_axis("blockIdx.x") + tidx = te.thread_axis("threadIdx.x") + gridDim = fused_indices_dimension # 32 * 600 = 19200 + blockDim = fused_updates_dimension + ib.scope_attr(bidx, "thread_extent", gridDim) + ib.scope_attr(tidx, "thread_extent", blockDim) + + j = tidx with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension - index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part - # of the index into out. - for l in reversed(range(indices_ptr.shape[0].value)): + findex = j + for l in reversed(range(indices_ptr.shape[0].value)): # 2, 1, 0 # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - index += offset * indices[i + l * fused_indices_dimension] + findex += offset * indices[bidx + l * fused_indices_dimension] offset *= data_ptr.shape[l] if mode == "update": - out[index] = updates[i * fused_updates_dimension + j] + out[findex] = updates[bidx * fused_updates_dimension + tidx] elif mode == "add": - out[index] += updates[i * fused_updates_dimension + j] + out[findex] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of", out[findex]), updates[bidx * fused_updates_dimension + j]) else: raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) From 833561ba2238b7d6b0a5072b06724ba812cd9069 Mon Sep 17 00:00:00 2001 From: wenxizhu Date: Thu, 15 Jul 2021 17:35:01 +0800 Subject: [PATCH 02/19] Fix python code format. --- python/tvm/topi/cuda/scatter.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 4808a3293523..f4ffcaf97fbc 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -806,7 +806,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.new_scope(): bidx = te.thread_axis("blockIdx.x") tidx = te.thread_axis("threadIdx.x") - gridDim = fused_indices_dimension # 32 * 600 = 19200 + gridDim = fused_indices_dimension # 32 * 600 = 19200 blockDim = fused_updates_dimension ib.scope_attr(bidx, "thread_extent", gridDim) ib.scope_attr(tidx, "thread_extent", blockDim) @@ -815,14 +815,17 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension findex = j - for l in reversed(range(indices_ptr.shape[0].value)): # 2, 1, 0 + for l in reversed(range(indices_ptr.shape[0].value)): # 2, 1, 0 # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] findex += offset * indices[bidx + l * fused_indices_dimension] offset *= data_ptr.shape[l] if mode == "update": out[findex] = updates[bidx * fused_updates_dimension + tidx] elif mode == "add": - out[findex] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of", out[findex]), updates[bidx * fused_updates_dimension + j]) + out[findex] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out[findex]), + updates[bidx * fused_updates_dimension + j], + ) else: raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) From a6effec646db1963b9df8478440f9b500fe0ce81 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Tue, 20 Jul 2021 19:33:25 +0800 Subject: [PATCH 03/19] FIX: [TOPI][CUDA] Improve the performance of scatter_nd #8479 - Split ScatterND kernel into 2 sub-kernels using ib.new_scope() - Replace ib.for_range() with blockIdx.y - Using atomic_add when mode == "add" - Keep threadIdx.x less than max_threads of GPU --- python/tvm/topi/cuda/scatter.py | 60 +++++++++++++++------------------ 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index f4ffcaf97fbc..5929ca199b42 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -772,9 +772,8 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - # We combine all the indices dimensions but the first one into a single - # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the update dimensions. + atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local") + fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i @@ -787,45 +786,42 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): for i in data_ptr.shape: fused_shape *= i - # Init output tensor. + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + tdim = min(max_threads, fused_updates_dimension) + with ib.new_scope(): - bidx = te.thread_axis("blockIdx.x") - tidx = te.thread_axis("threadIdx.x") - gridDim = 1 - for i in data_ptr.shape[:-1]: - gridDim *= i - blockDim = data_ptr.shape[-1] - - ib.scope_attr(bidx, "thread_extent", gridDim) - ib.scope_attr(tidx, "thread_extent", blockDim) - index = bidx * blockDim + tidx + bdim = ceil_div(fused_shape, tdim) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim) + ib.scope_attr(tx, "thread_extent", tdim) + + index = bx * tdim + tx with ib.if_scope(index < fused_shape): out[index] = data[index] - # Update output tensor by given values. with ib.new_scope(): - bidx = te.thread_axis("blockIdx.x") - tidx = te.thread_axis("threadIdx.x") - gridDim = fused_indices_dimension # 32 * 600 = 19200 - blockDim = fused_updates_dimension - ib.scope_attr(bidx, "thread_extent", gridDim) - ib.scope_attr(tidx, "thread_extent", blockDim) - - j = tidx + bdim_x = ceil_div(fused_updates_dimension, tdim) + bdim_y = fused_indices_dimension + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(by, "thread_extent", bdim_y) + ib.scope_attr(tx, "thread_extent", tdim) + + j = bx * tdim + tx with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension - findex = j - for l in reversed(range(indices_ptr.shape[0].value)): # 2, 1, 0 - # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - findex += offset * indices[bidx + l * fused_indices_dimension] + index = j + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[by * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[by + l * fused_indices_dimension] offset *= data_ptr.shape[l] if mode == "update": - out[findex] = updates[bidx * fused_updates_dimension + tidx] + out[index] = updates[by * fused_updates_dimension + j] elif mode == "add": - out[findex] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", out[findex]), - updates[bidx * fused_updates_dimension + j], - ) + atomic_add_return[0] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of", out[index]), updates[by * fused_updates_dimension + j]) else: raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) From 675947e6951caa74a86a851c8dbf26d646c20023 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Wed, 21 Jul 2021 17:13:07 +0800 Subject: [PATCH 04/19] Comment added --- python/tvm/topi/cuda/scatter.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 5929ca199b42..a17317f880cd 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -772,7 +772,9 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local") + atomic_add_return = ib.allocate( + updates.dtype, (1,), name="atomic_add_return", scope="local" + ) fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: @@ -786,6 +788,17 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): for i in data_ptr.shape: fused_shape *= i + # For now we avoid parallizing over dimensions indexed by `indices` as + # there may be repeated indices and hadling parallel accumulation can + # be hard. So we parallelize over X_M .. X_{N-1} instead. This will + # work well when these dimensions are large enough to saturate memory + # bandwidth, but performance will be bad when these dimensions are + # small. + + # For better performance, we introduce blockIdx.y to implement for-loops + # within one thread. + # Atomic_add guarantees correctness when mode=="add" + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) tdim = min(max_threads, fused_updates_dimension) @@ -814,14 +827,18 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension index = j + up_index = by * fused_updates_dimension + j for l in reversed(range(indices_ptr.shape[0].value)): # indices[by * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] index += offset * indices[by + l * fused_indices_dimension] offset *= data_ptr.shape[l] if mode == "update": - out[index] = updates[by * fused_updates_dimension + j] + out[index] = updates[up_index] elif mode == "add": - atomic_add_return[0] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of", out[index]), updates[by * fused_updates_dimension + j]) + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out[index]), + updates[up_index], + ) else: raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) From 1e1a617541be16f92e3206141f31e008e6d9fd38 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Thu, 22 Jul 2021 17:03:52 +0800 Subject: [PATCH 05/19] Add fallback implementation when "mode=add" meets int64 - Atomic_add from CUDA doesn't support int64 data type - Change "ind{i}" to "ind%d"%i, where names of relay.var could correctly display --- python/tvm/topi/cuda/scatter.py | 83 +++++++++++++++++----------- tests/python/relay/test_op_level3.py | 2 +- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index a17317f880cd..c0b7a906b4b1 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -790,14 +790,12 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # For now we avoid parallizing over dimensions indexed by `indices` as # there may be repeated indices and hadling parallel accumulation can - # be hard. So we parallelize over X_M .. X_{N-1} instead. This will - # work well when these dimensions are large enough to saturate memory - # bandwidth, but performance will be bad when these dimensions are - # small. + # be hard. So we parallelize over X_M .. X_{N-1} instead. # For better performance, we introduce blockIdx.y to implement for-loops # within one thread. - # Atomic_add guarantees correctness when mode=="add" + # The code is parallel over the scattered indices, so we use atomic_add + # to guarantee correctness when mode=="add" max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) tdim = min(max_threads, fused_updates_dimension) @@ -814,33 +812,54 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): out[index] = data[index] with ib.new_scope(): - bdim_x = ceil_div(fused_updates_dimension, tdim) - bdim_y = fused_indices_dimension - bx = te.thread_axis("blockIdx.x") - by = te.thread_axis("blockIdx.y") - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(bx, "thread_extent", bdim_x) - ib.scope_attr(by, "thread_extent", bdim_y) - ib.scope_attr(tx, "thread_extent", tdim) - - j = bx * tdim + tx - with ib.if_scope(j < fused_updates_dimension): - offset = fused_updates_dimension - index = j - up_index = by * fused_updates_dimension + j - for l in reversed(range(indices_ptr.shape[0].value)): - # indices[by * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - index += offset * indices[by + l * fused_indices_dimension] - offset *= data_ptr.shape[l] - if mode == "update": - out[index] = updates[up_index] - elif mode == "add": - atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", out[index]), - updates[up_index], - ) - else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + if updates.dtype == 'int64' and mode == 'add': + bdim_x = ceil_div(fused_updates_dimension, tdim) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(tx, "thread_extent", tdim) + with ib.for_range(0, fused_indices_dimension) as i: + j = bx * tdim + tx + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part + # of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + out[index] += updates[i * fused_updates_dimension + j] + else: + bdim_x = ceil_div(fused_updates_dimension, tdim) + bdim_y = fused_indices_dimension + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(by, "thread_extent", bdim_y) + ib.scope_attr(tx, "thread_extent", tdim) + + j = bx * tdim + tx + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part + # of the index into out. + up_index = by * fused_updates_dimension + j + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[by * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[by + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[up_index] + elif mode == "add": + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out[index]), + updates[up_index], + ) + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index fc67f0b90295..21bedd438902 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1884,7 +1884,7 @@ def verify_scatter_nd_with_stack( ): data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) indices_vars = [ - relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np) + relay.var("ind%d"%i, shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np) ] updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype)) From afdd9e717d191c33f88db63088339c7318301f11 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Thu, 22 Jul 2021 17:26:53 +0800 Subject: [PATCH 06/19] Python format --- python/tvm/topi/cuda/scatter.py | 4 ++-- tests/python/relay/test_op_level3.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index c0b7a906b4b1..472515d65fa2 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -812,7 +812,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): out[index] = data[index] with ib.new_scope(): - if updates.dtype == 'int64' and mode == 'add': + if updates.dtype == "int64" and mode == "add": bdim_x = ceil_div(fused_updates_dimension, tdim) bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") @@ -843,7 +843,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): j = bx * tdim + tx with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension - index = j # This is x_M, .. x_{N-1} part of the index into out. + index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part # of the index into out. up_index = by * fused_updates_dimension + j diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 21bedd438902..96a7a7a95e49 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1884,7 +1884,8 @@ def verify_scatter_nd_with_stack( ): data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) indices_vars = [ - relay.var("ind%d"%i, shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np) + relay.var("ind%d" % i, shape=v.shape, dtype=str(v.dtype)) + for i, v in enumerate(indices_np) ] updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype)) From 4f22477de8992033550d85cff73f29606f8f2a41 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Thu, 22 Jul 2021 17:41:39 +0800 Subject: [PATCH 07/19] Fix line too long --- python/tvm/topi/cuda/scatter.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 472515d65fa2..4dfbf20f2fd2 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -823,10 +823,12 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part - # of the index into out. + # Build up the + # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}] + # part of the index into out. for l in reversed(range(indices_ptr.shape[0].value)): - # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + # indices[i * l * fused_indices_dimension] = indices[l, y_0, + # ... y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] offset *= data_ptr.shape[l] out[index] += updates[i * fused_updates_dimension + j] @@ -844,8 +846,8 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part - # of the index into out. + # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] + # part of the index into out. up_index = by * fused_updates_dimension + j for l in reversed(range(indices_ptr.shape[0].value)): # indices[by * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] From fd573c58b8d841f9222b5f5016615265e573394c Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Sat, 24 Jul 2021 10:03:24 +0800 Subject: [PATCH 08/19] CI pass From a4373d08878effb90598adf1412d99209eda42d8 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Sat, 24 Jul 2021 17:14:22 +0800 Subject: [PATCH 09/19] Empty, for CI pass From d3fb5a209afdd23020c1146d383908efde774adb Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Sun, 25 Jul 2021 09:59:00 +0800 Subject: [PATCH 10/19] Empty, for CI pass From 1faa97a6cf2ae3beed83c2d94a0655257574909e Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Mon, 26 Jul 2021 14:54:57 +0800 Subject: [PATCH 11/19] Empty, for CI pass From 92af1835bdbf0477cf64339461388ba69c70fbe7 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Mon, 26 Jul 2021 16:09:27 +0800 Subject: [PATCH 12/19] Empty, for CI pass From 7d940b04461dab9be648460cef568e71722f336f Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Tue, 27 Jul 2021 16:24:19 +0800 Subject: [PATCH 13/19] Empty, for CI pass From c264949edf1dd27d52052d0c7723f5aa7fc75d04 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Wed, 28 Jul 2021 12:14:49 +0800 Subject: [PATCH 14/19] Exchange blockIdx.x and blockIdx.y --- python/tvm/topi/cuda/scatter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 4dfbf20f2fd2..4ad7743f84b1 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -835,8 +835,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): else: bdim_x = ceil_div(fused_updates_dimension, tdim) bdim_y = fused_indices_dimension - bx = te.thread_axis("blockIdx.x") - by = te.thread_axis("blockIdx.y") + # In case of large input sizes, bim_y might be too large. + # So it could be moved to blockIdx.x position, which holds larger scales. + bx = te.thread_axis("blockIdx.y") + by = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") ib.scope_attr(bx, "thread_extent", bdim_x) ib.scope_attr(by, "thread_extent", bdim_y) From c319e39a453e72d20967d422aa28fa3aefd90f47 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Thu, 29 Jul 2021 11:43:13 +0800 Subject: [PATCH 15/19] check for Vulkan or metal --- python/tvm/topi/cuda/scatter.py | 43 ++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 4ad7743f84b1..77c36473cfe1 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -764,6 +764,9 @@ def scatter_nd(data, indices, updates, mode): """ _verify_scatter_nd_inputs(data, indices, updates) + def cur_target_kind(kind="cuda"): + return tvm.target.Target.current(allow_none=False).kind == tvm.target.Target(kind).kind + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = tvm.tir.ir_builder.create() @@ -788,15 +791,6 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): for i in data_ptr.shape: fused_shape *= i - # For now we avoid parallizing over dimensions indexed by `indices` as - # there may be repeated indices and hadling parallel accumulation can - # be hard. So we parallelize over X_M .. X_{N-1} instead. - - # For better performance, we introduce blockIdx.y to implement for-loops - # within one thread. - # The code is parallel over the scattered indices, so we use atomic_add - # to guarantee correctness when mode=="add" - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) tdim = min(max_threads, fused_updates_dimension) @@ -811,8 +805,19 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.if_scope(index < fused_shape): out[index] = data[index] + # For better performance, we introduce blockIdx.y to implement for-loops + # within one thread. + # The code is parallel over the scattered indices, so we use atomic_add + # to guarantee correctness when mode=="add" + + # For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64" + # So we fallback to normal algorithm using "+=" rather than atomic_add with ib.new_scope(): - if updates.dtype == "int64" and mode == "add": + if ( + cur_target_kind("vulkan") + or cur_target_kind("metal") + or (updates.dtype == "int64" and mode == "add") + ): bdim_x = ceil_div(fused_updates_dimension, tdim) bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") @@ -833,27 +838,27 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): offset *= data_ptr.shape[l] out[index] += updates[i * fused_updates_dimension + j] else: - bdim_x = ceil_div(fused_updates_dimension, tdim) - bdim_y = fused_indices_dimension - # In case of large input sizes, bim_y might be too large. + bdim_x = fused_indices_dimension + bdim_y = ceil_div(fused_updates_dimension, tdim) + # In case of large input sizes, fused_indices_dimension might be too large. # So it could be moved to blockIdx.x position, which holds larger scales. - bx = te.thread_axis("blockIdx.y") - by = te.thread_axis("blockIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") tx = te.thread_axis("threadIdx.x") ib.scope_attr(bx, "thread_extent", bdim_x) ib.scope_attr(by, "thread_extent", bdim_y) ib.scope_attr(tx, "thread_extent", tdim) - j = bx * tdim + tx + j = by * tdim + tx with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] # part of the index into out. - up_index = by * fused_updates_dimension + j + up_index = bx * fused_updates_dimension + j for l in reversed(range(indices_ptr.shape[0].value)): - # indices[by * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] - index += offset * indices[by + l * fused_indices_dimension] + # indices[bx * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] + index += offset * indices[bx + l * fused_indices_dimension] offset *= data_ptr.shape[l] if mode == "update": out[index] = updates[up_index] From bac7b65977e98a84ce0766437a52b3340ac80385 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Thu, 29 Jul 2021 17:10:53 +0800 Subject: [PATCH 16/19] Fallback to previous algorithm when mode==update --- python/tvm/topi/cuda/scatter.py | 36 +++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 77c36473cfe1..b211ac0d5364 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -811,10 +811,18 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # to guarantee correctness when mode=="add" # For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64" - # So we fallback to normal algorithm using "+=" rather than atomic_add + # So we fallback to normal algorithm, using "+=" rather than atomic_add + + # TODO: + # Since multiple threads compete for the same write index, which leads to + # non-determinstic output for update mode. We could add a new attribute + # "allow_non_deterministic" to scatter_nd op, which is False by default. + # And change ONNX frontend to emit scatter_op with allow_non_deterministic = True, + # which will allow the new code path for update mode as well with ib.new_scope(): if ( - cur_target_kind("vulkan") + mode == "update" + or cur_target_kind("vulkan") or cur_target_kind("metal") or (updates.dtype == "int64" and mode == "add") ): @@ -836,8 +844,13 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # ... y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] offset *= data_ptr.shape[l] - out[index] += updates[i * fused_updates_dimension + j] - else: + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + elif mode == "add": bdim_x = fused_indices_dimension bdim_y = ceil_div(fused_updates_dimension, tdim) # In case of large input sizes, fused_indices_dimension might be too large. @@ -860,15 +873,12 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # indices[bx * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] index += offset * indices[bx + l * fused_indices_dimension] offset *= data_ptr.shape[l] - if mode == "update": - out[index] = updates[up_index] - elif mode == "add": - atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", out[index]), - updates[up_index], - ) - else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out[index]), + updates[up_index], + ) + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() From 3cf534ce4f7c7136c421d6fe6f0f529a318cae36 Mon Sep 17 00:00:00 2001 From: Duke Wang Date: Fri, 30 Jul 2021 09:50:56 +0800 Subject: [PATCH 17/19] Update python/tvm/topi/cuda/scatter.py Co-authored-by: Tristan Konolige --- python/tvm/topi/cuda/scatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index b211ac0d5364..cd785ca99c6a 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -854,7 +854,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): bdim_x = fused_indices_dimension bdim_y = ceil_div(fused_updates_dimension, tdim) # In case of large input sizes, fused_indices_dimension might be too large. - # So it could be moved to blockIdx.x position, which holds larger scales. + # So we use blockIdx.x because holds larger scales. bx = te.thread_axis("blockIdx.x") by = te.thread_axis("blockIdx.y") tx = te.thread_axis("threadIdx.x") From 7c361c94054cdc09b610296d7f4edb4181921bb3 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Fri, 30 Jul 2021 09:55:14 +0800 Subject: [PATCH 18/19] Assign TODO --- python/tvm/topi/cuda/scatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index cd785ca99c6a..b9f64ca8498e 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -813,7 +813,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # For now, atomic is not supported by target "vulkan", "metal", or "cuda" with "int64" # So we fallback to normal algorithm, using "+=" rather than atomic_add - # TODO: + # TODO (CaptainDuke): # Since multiple threads compete for the same write index, which leads to # non-determinstic output for update mode. We could add a new attribute # "allow_non_deterministic" to scatter_nd op, which is False by default. From 31fbde5ce2771334a4462b659de2a321304158b2 Mon Sep 17 00:00:00 2001 From: CaptainDuke Date: Sun, 1 Aug 2021 00:26:39 +0800 Subject: [PATCH 19/19] Swapping then and else block --- python/tvm/topi/cuda/scatter.py | 68 +++++++++++++++------------------ 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index b9f64ca8498e..fa7545cd323a 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -764,9 +764,6 @@ def scatter_nd(data, indices, updates, mode): """ _verify_scatter_nd_inputs(data, indices, updates) - def cur_target_kind(kind="cuda"): - return tvm.target.Target.current(allow_none=False).kind == tvm.target.Target(kind).kind - def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = tvm.tir.ir_builder.create() @@ -815,42 +812,16 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # TODO (CaptainDuke): # Since multiple threads compete for the same write index, which leads to - # non-determinstic output for update mode. We could add a new attribute - # "allow_non_deterministic" to scatter_nd op, which is False by default. - # And change ONNX frontend to emit scatter_op with allow_non_deterministic = True, - # which will allow the new code path for update mode as well + # non-determinstic output for update mode. We could add a new attribute, + # "allow_non_deterministic", which can be conditionally set to True by + # each frontend when non-determinsm is allowed. + cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind) with ib.new_scope(): if ( - mode == "update" - or cur_target_kind("vulkan") - or cur_target_kind("metal") - or (updates.dtype == "int64" and mode == "add") + mode == "add" + and cur_target_kind not in ["vulkan", "metal"] + and updates.dtype in ["int32", "float32"] ): - bdim_x = ceil_div(fused_updates_dimension, tdim) - bx = te.thread_axis("blockIdx.x") - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(bx, "thread_extent", bdim_x) - ib.scope_attr(tx, "thread_extent", tdim) - with ib.for_range(0, fused_indices_dimension) as i: - j = bx * tdim + tx - with ib.if_scope(j < fused_updates_dimension): - offset = fused_updates_dimension - index = j # This is x_M, .. x_{N-1} part of the index into out. - # Build up the - # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}] - # part of the index into out. - for l in reversed(range(indices_ptr.shape[0].value)): - # indices[i * l * fused_indices_dimension] = indices[l, y_0, - # ... y_{k-1}] - index += offset * indices[i + l * fused_indices_dimension] - offset *= data_ptr.shape[l] - if mode == "update": - out[index] = updates[i * fused_updates_dimension + j] - elif mode == "add": - out[index] += updates[i * fused_updates_dimension + j] - else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) - elif mode == "add": bdim_x = fused_indices_dimension bdim_y = ceil_div(fused_updates_dimension, tdim) # In case of large input sizes, fused_indices_dimension might be too large. @@ -878,7 +849,30 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates[up_index], ) else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + bdim_x = ceil_div(fused_updates_dimension, tdim) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", bdim_x) + ib.scope_attr(tx, "thread_extent", tdim) + with ib.for_range(0, fused_indices_dimension) as i: + j = bx * tdim + tx + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension + index = j # This is x_M, .. x_{N-1} part of the index into out. + # Build up the + # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}] + # part of the index into out. + for l in reversed(range(indices_ptr.shape[0].value)): + # indices[i * l * fused_indices_dimension] = indices[l, y_0, + # ... y_{k-1}] + index += offset * indices[i + l * fused_indices_dimension] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get()