diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index cee13d7e01a2..652a109abe7d 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -788,7 +788,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): fused_shape *= i # For now we avoid parallizing over dimensions indexed by `indices` as - # there may be repeated indices and hadling parallel accumulation can + # there may be repeated indices and handling parallel accumulation can # be hard. So we parallelize over X_M .. X_{N-1} instead. This will # work well when these dimensions are large enough to saturate memory # bandwidth, but performance will be bad when these dimensions are @@ -801,12 +801,14 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): bdim = ceil_div(fused_updates_dimension, tdim) ib.scope_attr(bx, "thread_extent", bdim) - with ib.for_range(0, ceil_div(fused_shape, bdim)) as i: + # Copy data into the output. This loop writes to the same portions of + # memory as the following loop, so we do not need a memory sync. + with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i: index = i * fused_updates_dimension + bx * tdim + tx - with ib.if_scope(index < fused_shape): + with ib.if_scope(bx * tdim + tx < fused_updates_dimension): out[index] = data[index] - with ib.for_range(0, fused_indices_dimension) as i: + with ib.for_range(0, fused_indices_dimension, name="i") as i: j = bx * tdim + tx with ib.if_scope(j < fused_updates_dimension): offset = fused_updates_dimension diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 0fe29f315b43..afb0d6633a2b 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -203,7 +203,7 @@ def _verify_scatter_nd_inputs(data, indices, updates): mdim = int(indices.shape[0]) assert mdim <= len(data.shape), ( f"The first dimension of the indices ({mdim}) must be less than or equal to " - f"the length of the shape of the output ({len(shape)})." + f"the length of the shape of the output ({len(data.shape)})." ) for i in range(len(indices.shape) - 1): if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):