diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index edcd5b9ce161..e42b8bbae814 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -113,7 +113,7 @@ def compute_scatter_add(attrs, inputs, output_type):
     return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]
 
 
-_reg.register_schedule("scatter_add", strategy.schedule_scatter_add)
+_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
 
 #####################
 #  Shape functions  #
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index b7ceda304639..26e9a0060b66 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -664,7 +664,7 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
 
 @scatter_strategy.register(["cuda", "gpu"])
 def scatter_cuda(attrs, inputs, out_type, target):
-    """sparse dense cuda strategy"""
+    """scatter cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
         wrap_compute_scatter(topi.cuda.scatter),
@@ -675,6 +675,19 @@ def scatter_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@scatter_add_strategy.register(["cuda", "gpu"])
+def scatter_add_cuda(attrs, inputs, out_type, target):
+    """scatter_add cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.cuda.scatter_add),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter_add.cuda",
+        plevel=10,
+    )
+    return strategy
+
+
 @argsort_strategy.register(["cuda", "gpu"])
 def argsort_strategy_cuda(attrs, inputs, out_type, target):
     """argsort cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index bdefbcb79009..e49135c4d1bf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1052,12 +1052,16 @@ def _compute_scatter(attrs, inputs, _):
     return _compute_scatter
 
 
-# scatter_add
-@generic_func
-def schedule_scatter_add(attrs, outs, target):
-    """schedule scatter_add"""
-    with target:
-        return topi.generic.schedule_scatter_add(outs)
+@override_native_generic_func("scatter_add_strategy")
+def scatter_add_strategy(attrs, outs, out_type, target):
+    """scatter_add generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.scatter_add),
+        wrap_topi_schedule(topi.generic.schedule_scatter),
+        name="scatter_add.generic",
+    )
+    return strategy
 
 
 # bitserial_conv2d
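Taken together, the strategy changes replace the old register_schedule hook with a full OpStrategy for scatter_add: the generic strategy reuses topi.scatter_add with schedule_scatter, and the CUDA registration dispatches to the new topi.cuda.scatter_add below at a higher priority level. A minimal sketch of the path this enables, assuming a GPU-enabled TVM build (the shapes and values are illustrative, not part of the patch):

import tvm
from tvm import relay

# scatter_add on axis 0: rows of `updates` are accumulated into `data`
# at the rows named by `indices`.
data = relay.var("data", shape=(3, 5), dtype="float32")
indices = relay.var("indices", shape=(2, 5), dtype="int64")
updates = relay.var("updates", shape=(2, 5), dtype="float32")
func = relay.Function(
    [data, indices, updates], relay.scatter_add(data, indices, updates, axis=0)
)

# With scatter_add_strategy registered for "cuda", this build now succeeds
# on GPU targets instead of failing for lack of a schedule.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(tvm.IRModule.from_expr(func), target="cuda")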
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 6522d74d8bef..0a3e96f4be30 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -24,7 +24,7 @@ def ceil_div(a, b):
     return (a + b - 1) // b
 
 
-def gen_ir_1d(data, indices, updates, axis, out):
+def gen_ir_1d(data, indices, updates, axis, out, update_func):
     """Generate scatter ir for 1d inputs
 
     Parameters
@@ -44,6 +44,9 @@ def gen_ir_1d(data, indices, updates, axis, out):
     out : tir.Tensor
         The output tensor.
 
+    update_func : function
+        The function to be applied to a destination and the corresponding update
+
     Returns
     -------
     ret : tir
@@ -73,14 +76,14 @@ def gen_ir_1d(data, indices, updates, axis, out):
         with ib.for_range(0, ni, name="i") as i:
             index = indices_ptr[i]
             with ib.if_scope(index < 0):
-                out_ptr[index + n] = updates_ptr[i]
+                update_func(out_ptr, index + n, updates_ptr[i])
             with ib.else_scope():
-                out_ptr[index] = updates_ptr[i]
+                update_func(out_ptr, index, updates_ptr[i])
 
     return ib.get()
 
 
-def gen_ir_2d(data, indices, updates, axis, out):
+def gen_ir_2d(data, indices, updates, axis, out, update_func):
     """Generate scatter ir for 2d inputs
 
     Parameters
@@ -100,6 +103,9 @@ def gen_ir_2d(data, indices, updates, axis, out):
     out : tir.Tensor
         The output tensor.
 
+    update_func : function
+        The function to be applied to a destination and the corresponding update
+
     Returns
     -------
     ret : tir
@@ -140,9 +146,9 @@ def gen_ir_2d(data, indices, updates, axis, out):
                 idx = i * ci + j
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[(index + n) * c + j] = updates_ptr[idx]
+                    update_func(out_ptr, (index + n) * c + j, updates_ptr[idx])
                 with ib.else_scope():
-                    out_ptr[index * c + j] = updates_ptr[idx]
+                    update_func(out_ptr, index * c + j, updates_ptr[idx])
     else:
         with ib.new_scope():
             i = te.thread_axis("blockIdx.x")
@@ -151,13 +157,13 @@ def gen_ir_2d(data, indices, updates, axis, out):
                 idx = i * ci + j
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[i * c + (index + c)] = updates_ptr[idx]
+                    update_func(out_ptr, i * c + (index + c), updates_ptr[idx])
                 with ib.else_scope():
-                    out_ptr[i * c + index] = updates_ptr[idx]
+                    update_func(out_ptr, i * c + index, updates_ptr[idx])
     return ib.get()
 
 
-def gen_ir_3d(data, indices, updates, axis, out):
+def gen_ir_3d(data, indices, updates, axis, out, update_func):
     """Generate scatter ir for 3d inputs
 
     Parameters
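The IR above encodes two conventions worth spelling out: a negative index wraps around the scattered axis (hence index + n on axis 0 and index + c on axis 1), and destinations are addressed through the row-major flattening of the output, so element (i, j) of an (n, c) tensor lives at flat offset i * c + j. A plain-NumPy analogue of the 2d, axis=0 case, with the update behavior injected the same way (an illustrative sketch, not part of the patch):

import numpy as np

def scatter_2d_axis0(data, indices, updates, update_func):
    # Mirrors gen_ir_2d for axis=0: visit every position of `indices`,
    # wrap negative indices by the axis length, and apply the injected
    # update at the row-major flat offset of the destination.
    out = data.copy()
    n, c = data.shape
    ni, ci = indices.shape
    flat = out.reshape(-1)  # a view over the same buffer
    for i in range(ni):
        for j in range(ci):
            index = indices[i, j]
            dst = (index + n) * c + j if index < 0 else index * c + j
            update_func(flat, dst, updates[i, j])
    return out

An overwriting update (flat[dst] = v) makes this scatter; an accumulating one (flat[dst] += v) makes it scatter_add, which is exactly the split the wrappers below introduce.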
@@ -177,6 +183,9 @@ def gen_ir_3d(data, indices, updates, axis, out):
     out : tir.Tensor
         The output tensor.
 
+    update_func : function
+        The function to be applied to a destination and the corresponding update
+
     Returns
     -------
     ret : tir
@@ -225,9 +234,9 @@ def gen_ir_3d(data, indices, updates, axis, out):
                 idx = (i * ci + j) * hi + k
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[((index + n) * c + j) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, ((index + n) * c + j) * h + k, updates_ptr[idx])
                 with ib.else_scope():
-                    out_ptr[(index * c + j) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, (index * c + j) * h + k, updates_ptr[idx])
     elif axis == 1:
         with ib.new_scope():
             i = te.thread_axis("blockIdx.x")
@@ -241,9 +250,9 @@ def gen_ir_3d(data, indices, updates, axis, out):
                 idx = (i * ci + j) * hi + k
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + (index + c)) * h + k, updates_ptr[idx])
                 with ib.else_scope():
-                    out_ptr[(i * c + index) * h + k] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + index) * h + k, updates_ptr[idx])
     else:
         with ib.new_scope():
             i = te.thread_axis("blockIdx.x")
@@ -254,13 +263,13 @@ def gen_ir_3d(data, indices, updates, axis, out):
                 idx = (i * ci + j) * hi + k
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + j) * h + (index + h), updates_ptr[idx])
                 with ib.else_scope():
-                    out_ptr[(i * c + j) * h + index] = updates_ptr[idx]
+                    update_func(out_ptr, (i * c + j) * h + index, updates_ptr[idx])
     return ib.get()
 
 
-def gen_ir_4d(data, indices, updates, axis, out):
+def gen_ir_4d(data, indices, updates, axis, out, update_func):
     """Generate scatter ir for 4d inputs
 
     Parameters
@@ -280,6 +289,9 @@ def gen_ir_4d(data, indices, updates, axis, out):
     out : tir.Tensor
         The output tensor.
 
+    update_func : function
+        The function to be applied to a destination and the corresponding update
+
     Returns
     -------
     ret : tir
@@ -333,9 +345,13 @@ def gen_ir_4d(data, indices, updates, axis, out):
                 idx = ((i * ci + j) * hi + k) * wi + l
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, (((index + n) * c + j) * h + k) * w + l, updates_ptr[idx]
+                    )
                 with ib.else_scope():
-                    out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((index * c + j) * h + k) * w + l, updates_ptr[idx]
+                    )
     elif axis == 1:
         with ib.new_scope():
             i = te.thread_axis("blockIdx.x")
@@ -351,9 +367,13 @@ def gen_ir_4d(data, indices, updates, axis, out):
                 idx = ((i * ci + j) * hi + k) * wi + l
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + (index + c)) * h + k) * w + l, updates_ptr[idx]
+                    )
                 with ib.else_scope():
-                    out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + index) * h + k) * w + l, updates_ptr[idx]
+                    )
     elif axis == 2:
         with ib.new_scope():
             i = te.thread_axis("blockIdx.x")
@@ -369,9 +389,13 @@ def gen_ir_4d(data, indices, updates, axis, out):
                 idx = ((i * ci + j) * hi + k) * wi + l
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + j) * h + (index + h)) * w + l, updates_ptr[idx]
+                    )
                 with ib.else_scope():
-                    out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[idx]
+                    update_func(
+                        out_ptr, ((i * c + j) * h + index) * w + l, updates_ptr[idx]
+                    )
     else:
         with ib.new_scope():
             i = te.thread_axis("blockIdx.x")
@@ -384,10 +408,9 @@ def gen_ir_4d(data, indices, updates, axis, out):
                 idx = ((i * ci + j) * hi + k) * wi + l
                 index = indices_ptr[idx]
                 with ib.if_scope(index < 0):
-                    out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[idx]
+                    update_func(out_ptr, ((i * c + j) * h + k) * w + (index + w), updates_ptr[idx])
                 with ib.else_scope():
-                    out_ptr[((i * c + j) * h + k) * w + index] = updates_ptr[idx]
-
+                    update_func(out_ptr, ((i * c + j) * h + k) * w + index, updates_ptr[idx])
     return ib.get()
 
 
@@ -428,12 +451,15 @@ def scatter(data, indices, updates, axis=0):
         4: gen_ir_4d,
     }
 
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = update
+
     out_shape = data.shape
     out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
     out = te.extern(
         [out_shape],
         [data, indices, updates],
-        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0]),
+        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
         dtype=data.dtype,
         out_buffers=[out_buf],
         name="scatter_gpu",
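Threading update_func through every generator means scatter and scatter_add share all of the IR above and differ only in the one-line closure each wrapper injects: plain assignment for scatter, in-place addition for scatter_add. For reference, the intended scatter_add semantics across all supported ranks, in the spirit of the ref_scatter_add helper in the relay test at the end of this patch (a NumPy sketch, not part of the patch):

import numpy as np

def ref_scatter_add(data, indices, updates, axis=0):
    # For every coordinate of `indices`, redirect the component along
    # `axis` to the indexed position and accumulate the update there.
    # NumPy's native negative indexing matches the IR's wraparound.
    output = np.copy(data)
    for idx in np.ndindex(*indices.shape):
        dst = list(idx)
        dst[axis] = indices[idx]
        output[tuple(dst)] += updates[idx]
    return output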
@@ -441,3 +467,58 @@ def scatter(data, indices, updates, axis=0):
     )
 
     return out
+
+
+def scatter_add(data, indices, updates, axis=0):
+    """Update data by adding values in updates at positions defined by indices
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices : tir.Tensor
+        The index locations to update.
+
+    updates : tir.Tensor
+        The values to be added.
+
+    axis : int
+        The axis to scatter on
+
+    Returns
+    -------
+    ret : tir.Tensor
+        The computed result.
+    """
+    if axis < 0:
+        axis += len(data.shape)
+    assert axis >= 0
+    assert axis < len(data.shape)
+
+    rank = len(data.shape)
+    assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"
+
+    ir_funcs = {
+        1: gen_ir_1d,
+        2: gen_ir_2d,
+        3: gen_ir_3d,
+        4: gen_ir_4d,
+    }
+
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] += update
+
+    out_shape = data.shape
+    out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
+    out = te.extern(
+        [out_shape],
+        [data, indices, updates],
+        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_add_gpu",
+        tag="scatter_add_gpu",
+    )
+
+    return out
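One property of the design is worth noting: the generators keep the loop over the scattered axis sequential (an ib.for_range, as in gen_ir_1d above) and reserve the thread axes for the other dimensions, so repeated indices are applied one after another and scatter_add needs no atomics. Repeated indices therefore accumulate, matching PyTorch's scatter_add semantics; a quick cross-check, assuming a local torch install and reusing the shapes from test_forward.py below:

import torch

data = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
src = torch.rand(2, 5)
# For dim=0: out[index[i][j]][j] += src[i][j]; any position indexed more
# than once receives the sum of its contributions.
out = data.scatter_add(0, index, src)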
+ """ + if axis < 0: + axis += len(data.shape) + assert axis >= 0 + assert axis < len(data.shape) + + rank = len(data.shape) + assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions" + + ir_funcs = { + 1: gen_ir_1d, + 2: gen_ir_2d, + 3: gen_ir_3d, + 4: gen_ir_4d, + } + + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] += update + + out_shape = data.shape + out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") + out = te.extern( + [out_shape], + [data, indices, updates], + lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_add_gpu", + tag="scatter_add_gpu", + ) + + return out diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4dec5f7e5916..6250dfff811a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3149,17 +3149,17 @@ def test_fn_scatter_add(dim): in_data = torch.zeros(3, 5) in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) in_src = torch.rand(2, 5) - # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"]) - verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"]) + + targets = ["llvm", "cuda"] + verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets) + verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets) in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) - # # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"]) - verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], targets) + verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets) def test_numel(): diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 3ea0777df8ca..01de6d265cb2 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -961,6 +961,7 @@ def verify_dynamic_scatter(dshape, ishape, axis=0): verify_dynamic_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3) +@tvm.testing.uses_gpu def test_scatter_add(): def ref_scatter_add(data, indices, updates, axis=0): output = np.copy(data) @@ -983,8 +984,7 @@ def verify_scatter_add(dshape, ishape, axis=0): indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis) - # TODO(mbrookhart): expand testing when adding more backend schedules - for target, ctx in [("llvm", tvm.cpu())]: + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)