Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
fused_shape *= i

# For now we avoid parallelizing over dimensions indexed by `indices` as
# there may be repeated indices and hadling parallel accumulation can
# there may be repeated indices and handling parallel accumulation can
# be hard. So we parallelize over X_M .. X_{N-1} instead. This will
# work well when these dimensions are large enough to saturate memory
# bandwidth, but performance will be bad when these dimensions are
Expand All @@ -801,12 +801,14 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
bdim = ceil_div(fused_updates_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)

with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
# Copy data into the output. This loop writes to the same portions of
# memory as the following loop, so we do not need a memory sync.
with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i:
index = i * fused_updates_dimension + bx * tdim + tx
with ib.if_scope(index < fused_shape):
with ib.if_scope(bx * tdim + tx < fused_updates_dimension):
out[index] = data[index]

with ib.for_range(0, fused_indices_dimension) as i:
with ib.for_range(0, fused_indices_dimension, name="i") as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _verify_scatter_nd_inputs(data, indices, updates):
mdim = int(indices.shape[0])
assert mdim <= len(data.shape), (
f"The first dimension of the indices ({mdim}) must be less than or equal to "
f"the length of the shape of the output ({len(shape)})."
f"the length of the shape of the output ({len(data.shape)})."
)
for i in range(len(indices.shape) - 1):
if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
Expand Down