diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index c697b648786e..fa7545cd323a 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -772,9 +772,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         updates = ib.buffer_ptr(updates_ptr)
         out = ib.buffer_ptr(out_ptr)
 
-        # We combine all the indices dimensions but the first one into a single
-        # dimension so we can iterate it in single loop instead of an arbitrary
-        # number of loops. We do the same thing for all the update dimensions.
+        atomic_add_return = ib.allocate(
+            updates.dtype, (1,), name="atomic_add_return", scope="local"
+        )
+
         fused_indices_dimension = 1
         for i in indices_ptr.shape[1:]:
             fused_indices_dimension *= i
@@ -787,44 +788,91 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         for i in data_ptr.shape:
             fused_shape *= i
 
-        # For now we avoid parallizing over dimensions indexed by `indices` as
-        # there may be repeated indices and hadling parallel accumulation can
-        # be hard. So we parallelize over X_M .. X_{N-1} instead. This will
-        # work well when these dimensions are large enough to saturate memory
-        # bandwidth, but performance will be bad when these dimensions are
-        # small.
-        bx = te.thread_axis("blockIdx.x")
-        tx = te.thread_axis("threadIdx.x")
         max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
         tdim = min(max_threads, fused_updates_dimension)
-        ib.scope_attr(tx, "thread_extent", tdim)
-        bdim = ceil_div(fused_updates_dimension, tdim)
-        ib.scope_attr(bx, "thread_extent", bdim)
-
-        # Copy data into the output. This loop writes to the same portions of
-        # memory as the following loop, so we do not need a memory sync.
-        with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
-            index = i * fused_updates_dimension + bx * tdim + tx
-            with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
+
+        with ib.new_scope():
+            bdim = ceil_div(fused_shape, tdim)
+            bx = te.thread_axis("blockIdx.x")
+            tx = te.thread_axis("threadIdx.x")
+            ib.scope_attr(bx, "thread_extent", bdim)
+            ib.scope_attr(tx, "thread_extent", tdim)
+
+            index = bx * tdim + tx
+            with ib.if_scope(index < fused_shape):
                 out[index] = data[index]
 
-        with ib.for_range(0, fused_indices_dimension) as i:
-            j = bx * tdim + tx
-            with ib.if_scope(j < fused_updates_dimension):
-                offset = fused_updates_dimension
-                index = j  # This is x_M, .. x_{N-1} part of the index into out.
-                # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
-                # of the index into out.
-                for l in reversed(range(indices_ptr.shape[0].value)):
-                    # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
-                    index += offset * indices[i + l * fused_indices_dimension]
-                    offset *= data_ptr.shape[l]
-                if mode == "update":
-                    out[index] = updates[i * fused_updates_dimension + j]
-                elif mode == "add":
-                    out[index] += updates[i * fused_updates_dimension + j]
-                else:
-                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
+        # For better performance, we introduce blockIdx.y so that the for-loop
+        # formerly run inside each thread is parallelized across blocks.
+        # The code is parallel over the scattered indices, so we use atomic_add
+        # to guarantee correctness when mode == "add".
+
+        # For now, atomic_add is not supported on "vulkan" or "metal", or with "int64" on "cuda",
+        # so we fall back to the normal algorithm, using "+=" rather than atomic_add.
+
+        # TODO (CaptainDuke):
+        # Multiple threads may compete for the same write index, which leads to
+        # non-deterministic output in update mode. We could add a new attribute,
+        # "allow_non_deterministic", which can be conditionally set to True by
+        # each frontend when non-determinism is allowed.
+        cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind)
+        with ib.new_scope():
+            if (
+                mode == "add"
+                and cur_target_kind not in ["vulkan", "metal"]
+                and updates.dtype in ["int32", "float32"]
+            ):
+                bdim_x = fused_indices_dimension
+                bdim_y = ceil_div(fused_updates_dimension, tdim)
+                # In case of large input sizes, fused_indices_dimension might be too large
+                # for blockIdx.y, so we map it to blockIdx.x, which allows a larger extent.
+                bx = te.thread_axis("blockIdx.x")
+                by = te.thread_axis("blockIdx.y")
+                tx = te.thread_axis("threadIdx.x")
+                ib.scope_attr(bx, "thread_extent", bdim_x)
+                ib.scope_attr(by, "thread_extent", bdim_y)
+                ib.scope_attr(tx, "thread_extent", tdim)
+
+                j = by * tdim + tx
+                with ib.if_scope(j < fused_updates_dimension):
+                    offset = fused_updates_dimension
+                    index = j  # This is x_M, .. x_{N-1} part of the index into out.
+                    # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}]
+                    # part of the index into out.
+                    up_index = bx * fused_updates_dimension + j
+                    for l in reversed(range(indices_ptr.shape[0].value)):
+                        # indices[bx + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
+                        index += offset * indices[bx + l * fused_indices_dimension]
+                        offset *= data_ptr.shape[l]
+                    atomic_add_return[0] = atomic_add(
+                        tvm.tir.call_intrin("handle", "tir.address_of", out[index]),
+                        updates[up_index],
+                    )
+            else:
+                bdim_x = ceil_div(fused_updates_dimension, tdim)
+                bx = te.thread_axis("blockIdx.x")
+                tx = te.thread_axis("threadIdx.x")
+                ib.scope_attr(bx, "thread_extent", bdim_x)
+                ib.scope_attr(tx, "thread_extent", tdim)
+                with ib.for_range(0, fused_indices_dimension) as i:
+                    j = bx * tdim + tx
+                    with ib.if_scope(j < fused_updates_dimension):
+                        offset = fused_updates_dimension
+                        index = j  # This is x_M, .. x_{N-1} part of the index into out.
+                        # Build up the
+                        # indices[0, y_0, .. y_{K-1}], ... indices[M-1, y_0, .. y_{K-1}]
+                        # part of the index into out.
+                        for l in reversed(range(indices_ptr.shape[0].value)):
+                            # indices[i + l * fused_indices_dimension] = indices[l, y_0,
+                            # ... y_{k-1}]
+                            index += offset * indices[i + l * fused_indices_dimension]
+                            offset *= data_ptr.shape[l]
+                        if mode == "update":
+                            out[index] = updates[i * fused_updates_dimension + j]
+                        elif mode == "add":
+                            out[index] += updates[i * fused_updates_dimension + j]
+                        else:
+                            raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
         return ib.get()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index fc67f0b90295..96a7a7a95e49 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1884,7 +1884,8 @@ def verify_scatter_nd_with_stack(
     ):
         data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
         indices_vars = [
-            relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
+            relay.var("ind%d" % i, shape=v.shape, dtype=str(v.dtype))
+            for i, v in enumerate(indices_np)
         ]
         updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
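For reference, here is a minimal NumPy sketch of the index arithmetic the generated IR performs (illustrative only, not part of the patch; `scatter_nd_ref` and its variable names are made up for this note). The outer loop over the fused indices is what the "add" path maps to blockIdx.x, the inner loop over the fused update dimension is what blockIdx.y and threadIdx.x cover, and repeated indices landing on the same flat `index` are why that path needs atomic_add:

```python
import numpy as np


def scatter_nd_ref(data, indices, updates, mode="add"):
    """Reference scatter_nd semantics: indices has shape (M, Y_0, ..., Y_{K-1});
    each column indices[:, y] selects data[x_0, ..., x_{M-1}, ...], which receives
    the corresponding slice of updates."""
    out = data.copy()
    m = indices.shape[0]
    # fused_indices_dimension: all indices dims except the leading M
    fused_indices = int(np.prod(indices.shape[1:], dtype=np.int64))
    # fused_updates_dimension: the trailing data dims not addressed by indices
    fused_updates = int(np.prod(data.shape[m:], dtype=np.int64))
    flat_out = out.reshape(-1)
    flat_idx = indices.reshape(m, fused_indices)
    flat_upd = updates.reshape(fused_indices, fused_updates)
    for i in range(fused_indices):  # the "add" path parallelizes this over blockIdx.x
        for j in range(fused_updates):  # covered by blockIdx.y * tdim + threadIdx.x
            # Row-major flattening: j is the x_M .. x_{N-1} part; each scattered
            # index is then folded in from the innermost axis outwards.
            offset = fused_updates
            index = j
            for l in reversed(range(m)):
                index += offset * int(flat_idx[l, i])
                offset *= data.shape[l]
            if mode == "update":
                flat_out[index] = flat_upd[i, j]
            else:  # "add": concurrent writes to the same index need atomic_add on GPU
                flat_out[index] += flat_upd[i, j]
    return out


# Repeated indices make two updates land on the same output row, which is the
# race the atomic_add path guards against.
data = np.zeros((2, 3), dtype="float32")
indices = np.array([[0, 1, 1]])  # M = 1, three scattered rows
updates = np.ones((3, 3), dtype="float32")
print(scatter_nd_ref(data, indices, updates, mode="add"))
# [[1. 1. 1.]
#  [2. 2. 2.]]
```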