From 01cf7303142680efdf1e3582965c19513685afbe Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 6 Dec 2020 21:58:33 +0900 Subject: [PATCH 1/5] use atomic add for faster 1d scatter add --- python/tvm/relay/frontend/pytorch.py | 16 ++++-- python/tvm/topi/cuda/scatter.py | 81 +++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 4f75cf380cc6..015b7612e0ff 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1921,18 +1921,28 @@ def empty(self, inputs, input_types): def bincount(self, inputs, input_types): data = inputs[0] weights = inputs[1] + input_type = _infer_type(data).checked_type.dtype + if input_type == "int64": + logging.warning( + "Casting an int64 input to int32, since we do not have int64 atomic add needed for bincount yet." + ) + data = _op.cast(data, "int32") maximum = _op.max(data) - dim = maximum + _expr.const(1, dtype="int64") + dim = maximum + _expr.const(1, dtype="int32") if weights: weight_type = _infer_type(weights).checked_type out_dtype = weight_type.dtype updates = weights else: - out_dtype = "int64" + out_dtype = "int32" updates = _op.ones_like(data) counts = _op.zeros(_op.reshape(dim, [1]), out_dtype) - return _op.scatter_add(counts, data, updates, axis=0) + out = _op.scatter_add(counts, data, updates, axis=0) + if input_type == "int32": + # Torch always outputs int64 results for bincount + return _op.cast(out, "int64") + return out def scatter_add(self, inputs, input_types): data = inputs[0] diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 5e03fafcfb58..06535832247e 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -19,6 +19,7 @@ import tvm from tvm import te from ..scatter import _verify_scatter_nd_inputs +from .nms import atomic_add def ceil_div(a, b): @@ -470,6 +471,84 @@ def update_func(dst_ptr, dst_index, update): return out +def gen_scatter_add_1d(data, indices, updates, axis, out, _): + """Generate scatter ir for 1d inputs + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + update_func: function + The function to be applied to a destination and the corresponding update. + + Returns + ------- + ret : tir + The computational ir. 
+ """ + assert axis == 0 + n = data.shape[0] + + ib = tvm.tir.ir_builder.create() + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + + with ib.new_scope(): + nthread_bx = ceil_div(n, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + with ib.if_scope(tid < n): + out_ptr[tid] = data_ptr[tid] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + + ni = indices.shape[0] + + atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local") + + with ib.new_scope(): + nthread_bx = ceil_div(ni, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + with ib.if_scope(tid < ni): + index = indices_ptr[tid] + with ib.if_scope(index < 0): + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]), updates_ptr[tid] + ) + with ib.else_scope(): + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]), updates_ptr[tid] + ) + + return ib.get() + + def scatter_add(data, indices, updates, axis=0): """Update data by adding values in updates at positions defined by indices @@ -501,7 +580,7 @@ def scatter_add(data, indices, updates, axis=0): assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions" ir_funcs = { - 1: gen_ir_1d, + 1: gen_scatter_add_1d, 2: gen_ir_2d, 3: gen_ir_3d, 4: gen_ir_4d, From a1c7910b45de238df76ca5d2b1710446ca7311c1 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 7 Dec 2020 18:47:47 +0900 Subject: [PATCH 2/5] update tests --- python/tvm/topi/cuda/scatter.py | 9 +++------ tests/python/frontend/pytorch/test_forward.py | 11 ++++++----- tests/python/relay/test_op_level3.py | 4 ++++ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 06535832247e..5516a0444f8e 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -471,8 +471,8 @@ def update_func(dst_ptr, dst_index, update): return out -def gen_scatter_add_1d(data, indices, updates, axis, out, _): - """Generate scatter ir for 1d inputs +def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): + """Generate scatter add ir for 1d inputs, using atomic_add instruction Parameters ---------- @@ -491,9 +491,6 @@ def gen_scatter_add_1d(data, indices, updates, axis, out, _): out : tir.Tensor The output tensor. - update_func: function - The function to be applied to a destination and the corresponding update. 
- Returns ------- ret : tir @@ -580,7 +577,7 @@ def scatter_add(data, indices, updates, axis=0): assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions" ir_funcs = { - 1: gen_scatter_add_1d, + 1: gen_scatter_add_1d_atomic, 2: gen_ir_2d, 3: gen_ir_3d, 4: gen_ir_4d, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6250dfff811a..3f82f6fe1371 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3355,12 +3355,13 @@ def test_bincount(): def test_fn(x, weights=None): return torch.bincount(x, weights=weights) - inp = torch.randint(0, 8, (5,), dtype=torch.int64) - weights = torch.linspace(0, 1, steps=5) + inp = torch.randint(0, 100, (10000,), dtype=torch.int64) + weights = torch.linspace(0, 100, steps=10000) - verify_trace_model(test_fn, [inp], ["llvm"]) - verify_trace_model(test_fn, [inp, weights], ["llvm"]) - verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"]) + targets = ["llvm", "cuda"] + verify_trace_model(test_fn, [inp], targets) + verify_trace_model(test_fn, [inp, weights], targets) + verify_trace_model(test_fn, [inp, weights.to(torch.float64)], targets) if __name__ == "__main__": diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 82d056381666..fc1929e9dc18 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1017,11 +1017,15 @@ def verify_scatter_add(dshape, ishape, axis=0): ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis) for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: + if target == "nvptx": + # TODO(masahi): support atomic in LLVM codegen + continue intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) verify_scatter_add((10,), (10,), 0) + verify_scatter_add((1000,), (1000,), 0) verify_scatter_add((10, 5), (10, 5), -2) verify_scatter_add((10, 5), (10, 5), -1) verify_scatter_add((10, 5), (3, 5), 0) From 92f42553fb433267a6fbedc25e5ad2788dac8797 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 7 Dec 2020 18:53:55 +0900 Subject: [PATCH 3/5] run black --- python/tvm/topi/cuda/scatter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 5516a0444f8e..89c5cd23111b 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -536,11 +536,13 @@ def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): index = indices_ptr[tid] with ib.if_scope(index < 0): atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]), updates_ptr[tid] + tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]), + updates_ptr[tid], ) with ib.else_scope(): atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]), updates_ptr[tid] + tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]), + updates_ptr[tid], ) return ib.get() From 3e22d039d42f846c66cb922222a2eeba544d0098 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 7 Dec 2020 18:58:06 +0900 Subject: [PATCH 4/5] more pylint fix --- python/tvm/relay/frontend/pytorch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 
015b7612e0ff..d2c52fbc262a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1924,7 +1924,8 @@ def bincount(self, inputs, input_types):
         input_type = _infer_type(data).checked_type.dtype
         if input_type == "int64":
             logging.warning(
-                "Casting an int64 input to int32, since we do not have int64 atomic add needed for bincount yet."
+                "Casting an int64 input to int32, since we do not have int64 atomic add "
+                "needed for bincount yet."
             )
             data = _op.cast(data, "int32")
         maximum = _op.max(data)

From 4ffb294e790bb41499ca6eba54d6c51652c3e2ff Mon Sep 17 00:00:00 2001
From: masa
Date: Mon, 7 Dec 2020 22:43:35 +0900
Subject: [PATCH 5/5] remove fp64 bincount test

---
 tests/python/frontend/pytorch/test_forward.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 3f82f6fe1371..2dda675c74f5 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3361,7 +3361,6 @@ def test_fn(x, weights=None):
     targets = ["llvm", "cuda"]
     verify_trace_model(test_fn, [inp], targets)
     verify_trace_model(test_fn, [inp, weights], targets)
-    verify_trace_model(test_fn, [inp, weights.to(torch.float64)], targets)
 
 
 if __name__ == "__main__":
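
Note on the approach: taken together, the series replaces the generic 1-D scatter IR with a two-kernel scheme on CUDA. The first kernel copies data into out; the second launches one thread per update and calls atomic_add on out[index] (or out[index + n] when the index is negative), so colliding indices accumulate correctly even though updates run in parallel. As a rough illustration of the result being computed (not code from this patch; the helper name and the example values below are made up), here is a sequential NumPy sketch:

    import numpy as np

    def ref_scatter_add_1d(data, indices, updates):
        # Start from a copy of data, like the first copy kernel in the patch.
        out = data.copy()
        n = data.shape[0]
        for i in range(indices.shape[0]):
            idx = indices[i]
            # Negative indices wrap around, matching the index + n branch in the IR.
            if idx < 0:
                idx += n
            out[idx] += updates[i]
        return out

    # Duplicate indices are the case that needs atomics on the GPU.
    data = np.zeros(4, dtype="float32")
    indices = np.array([0, 1, 1, -1], dtype="int32")
    updates = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")
    print(ref_scatter_add_1d(data, indices, updates))  # [1. 5. 0. 4.]

bincount on the PyTorch frontend is lowered to exactly this pattern (a zeros tensor plus scatter_add with ones or the weights as updates), which is why the frontend casts int64 inputs to int32 until an int64 atomic add is available.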