From 4f0085aab527b5391bba54b946809fd71a113d80 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 22 Sep 2020 15:27:18 -0700 Subject: [PATCH 1/4] working cuda scatter fix lint fix pylint again --- python/tvm/relay/op/_transform.py | 2 +- python/tvm/relay/op/strategy/cuda.py | 13 + python/tvm/relay/op/strategy/generic.py | 23 +- python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/scatter.py | 374 ++++++++++++++++++++++++ tests/python/relay/test_op_level3.py | 5 +- 6 files changed, 410 insertions(+), 8 deletions(-) create mode 100644 python/tvm/topi/cuda/scatter.py diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 415529fdcb9a..28b8f5c01b9a 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -104,7 +104,7 @@ def compute_scatter(attrs, inputs, output_type): return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)] -_reg.register_schedule("scatter", strategy.schedule_scatter) +_reg.register_strategy("scatter", strategy.scatter_strategy) # scatter_add @_reg.register_compute("scatter_add") diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index ca44e49ce1dd..d77361d906fb 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -651,6 +651,19 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target): return strategy +@scatter_strategy.register(["cuda", "gpu"]) +def scatter_cuda(attrs, inputs, out_type, target): + """scatter cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter(topi.cuda.scatter), + wrap_topi_schedule(topi.generic.schedule_extern), + name="scatter.cuda", + plevel=10, + ) + return strategy + + @argsort_strategy.register(["cuda", "gpu"]) def argsort_strategy_cuda(attrs, inputs, out_type, target): """argsort cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py 
b/python/tvm/relay/op/strategy/generic.py index 0f9971012f3c..8933c38549b8 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1032,11 +1032,24 @@ def schedule_argwhere(attrs, outs, target): # scatter -@generic_func -def schedule_scatter(attrs, outs, target): - """schedule scatter""" - with target: - return topi.generic.schedule_scatter(outs) +@override_native_generic_func("scatter_strategy") +def scatter_strategy(attrs, outs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter(topi.scatter), + wrap_topi_schedule(topi.generic.schedule_scatter), + name="scatter.generic", + ) + return strategy + + +def wrap_compute_scatter(topi_compute): + """Wrap scatter topi compute""" + + def _compute_scatter(attrs, inputs, _): + return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)] + + return _compute_scatter # scatter_add diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index ed8037024635..3ff544f4bb3e 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -46,6 +46,7 @@ from .ssd import * from .nms import get_valid_counts, non_max_suppression from .rcnn import * +from .scatter import * from .sort import * from .conv2d_nhwc_tensorcore import * from .conv3d_ndhwc_tensorcore import * diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py new file mode 100644 index 000000000000..328f00186d79 --- /dev/null +++ b/python/tvm/topi/cuda/scatter.py @@ -0,0 +1,374 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument +"""Scatter operator """ +import tvm +from tvm import te + + +def gen_ir_1d(data, indices, updates, axis, out): + """Generate scatter ir for 1d inputs + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. + """ + assert axis == 0 + n = data.shape[0] + + ib = tvm.tir.ir_builder.create() + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", 1) + + with ib.for_range(0, n, name="i") as i: + out_ptr[i] = data_ptr[i] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] + + with ib.for_range(0, ni, name="i") as i: + index = indices_ptr[i] + with ib.if_scope(index < 0): + out_ptr[index + n] = updates_ptr[i] + with ib.else_scope(): + out_ptr[index] = updates_ptr[i] + + return ib.get() + + +def gen_ir_2d(data, indices, updates, axis, out): + """Generate scatter ir for 2d inputs + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. 
+ + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. + """ + n = data.shape[0] + c = data.shape[1] + + ib = tvm.tir.ir_builder.create() + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", 1) + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, c, name="j") as j: + out_ptr[i * c + j] = data_ptr[i * c + j] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] + ci = indices.shape[1] + + if axis == 0: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + index = indices_ptr[i * ci + j] + with ib.if_scope(index < 0): + out_ptr[(index + n) * c + j] = updates_ptr[i * ci + j] + with ib.else_scope(): + out_ptr[index * c + j] = updates_ptr[i * ci + j] + else: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + index = indices_ptr[i * ci + j] + with ib.if_scope(index < 0): + out_ptr[i * c + (index + c)] = updates_ptr[i * ci + j] + with ib.else_scope(): + out_ptr[i * c + index] = updates_ptr[i * ci + j] + + return ib.get() + + +def gen_ir_3d(data, indices, updates, axis, out): + """Generate scatter ir for 3d inputs + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. 
+ """ + n = data.shape[0] + c = data.shape[1] + h = data.shape[2] + + ib = tvm.tir.ir_builder.create() + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", 1) + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, c, name="j") as j: + with ib.for_range(0, h, name="k") as k: + out_ptr[(i * c + j) * h + k] = data_ptr[(i * c + j) * h + k] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] + ci = indices.shape[1] + hi = indices.shape[2] + + if axis == 0: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, hi, name="k") as k: + index = indices_ptr[(i * ci + j) * hi + k] + with ib.if_scope(index < 0): + out_ptr[((index + n) * c + j) * h + k] = updates_ptr[(i * ci + j) * hi + k] + with ib.else_scope(): + out_ptr[(index * c + j) * h + k] = updates_ptr[(i * ci + j) * hi + k] + elif axis == 1: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, hi, name="k") as k: + index = indices_ptr[(i * ci + j) * hi + k] + with ib.if_scope(index < 0): + out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[(i * ci + j) * hi + k] + with ib.else_scope(): + out_ptr[(i * c + index) * h + k] = updates_ptr[(i * ci + j) * hi + k] + else: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, hi, name="k") as k: + index = indices_ptr[(i * ci + j) * hi + k] + with ib.if_scope(index < 0): + out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[(i * ci + j) * hi + k] + with ib.else_scope(): + out_ptr[(i * c + j) * h + index] = updates_ptr[(i * ci + j) * hi + k] + + return ib.get() + + +def gen_ir_4d(data, indices, updates, axis, out): + """Generate scatter ir for 4d inputs + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. 
+ + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. + """ + n = data.shape[0] + c = data.shape[1] + h = data.shape[2] + w = data.shape[3] + + ib = tvm.tir.ir_builder.create() + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", 1) + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, c, name="j") as j: + with ib.for_range(0, h, name="k") as k: + with ib.for_range(0, w, name="l") as l: + out_ptr[((i * c + j) * h + k) * w + l] = data_ptr[((i * c + j) * h + k) * w + l] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] + ci = indices.shape[1] + hi = indices.shape[2] + wi = indices.shape[3] + + if axis == 0: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, hi, name="k") as k: + with ib.for_range(0, wi, name="l") as l: + index = indices_ptr[((i * ci + j) * hi + k) * wi + l] + with ib.if_scope(index < 0): + out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + with ib.else_scope(): + out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + elif axis == 1: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, hi, name="k") as k: + with ib.for_range(0, wi, name="l") as l: + index = indices_ptr[((i * ci + j) * hi + k) * wi + l] + with ib.if_scope(index < 0): + out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + with ib.else_scope(): + out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + elif axis == 2: + with ib.for_range(0, ni, 
name="i") as i: + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, hi, name="k") as k: + with ib.for_range(0, wi, name="l") as l: + index = indices_ptr[((i * ci + j) * hi + k) * wi + l] + with ib.if_scope(index < 0): + out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + with ib.else_scope(): + out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + else: + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, hi, name="k") as k: + with ib.for_range(0, wi, name="l") as l: + index = indices_ptr[((i * ci + j) * hi + k) * wi + l] + with ib.if_scope(index < 0): + out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + with ib.else_scope(): + out_ptr[((i * c + j) * h + k) * w + index] = updates_ptr[ + ((i * ci + j) * hi + k) * wi + l + ] + + return ib.get() + + +def scatter(data, indices, updates, axis=0): + """Update data at positions defined by indices with values in updates + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to update. + + axis : int + The axis to scatter on + + Returns + ------- + ret : relay.Expr + The computed result. 
+ """ + if axis < 0: + axis += len(data.shape) + assert axis >= 0 + assert axis < len(data.shape) + + rank = len(data.shape) + assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions" + + ir_funcs = { + 1: gen_ir_1d, + 2: gen_ir_2d, + 3: gen_ir_3d, + 4: gen_ir_4d, + } + + out_shape = data.shape + out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") + out = te.extern( + [out_shape], + [data, indices, updates], + lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_gpu", + tag="scatter_gpu", + ) + + return out diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index b01977759a4b..9477703edcd6 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -878,6 +878,7 @@ def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res): ) +@tvm.testing.uses_gpu def test_scatter(): def ref_scatter(data, indices, updates, axis=0): idx = np.indices(indices.shape).reshape(indices.ndim, -1) @@ -903,8 +904,8 @@ def verify_scatter(dshape, ishape, axis=0): indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") ref_res = ref_scatter(data_np, indices_np, updates_np, axis) - # TODO(mbrookhart): expand testing when adding more backend schedules - for target, ctx in [("llvm", tvm.cpu())]: + + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) From 37eeeae354a8c40ff52b73ec0d60372f4429c80d Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 24 Sep 2020 12:14:30 -0700 Subject: [PATCH 2/4] cuda scatter with threading --- python/tvm/topi/cuda/scatter.py | 286 +++++++++++++++++---------- tests/python/relay/test_op_level3.py | 1 + 2 files changed, 179 insertions(+), 108 deletions(-) diff --git 
a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 328f00186d79..c5c596287e9d 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -20,6 +20,10 @@ from tvm import te +def ceil_div(a, b): + return (a + b - 1) // b + + def gen_ir_1d(data, indices, updates, axis, out): """Generate scatter ir for 1d inputs @@ -53,22 +57,25 @@ def gen_ir_1d(data, indices, updates, axis, out): out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", 1) - - with ib.for_range(0, n, name="i") as i: - out_ptr[i] = data_ptr[i] + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", n) + out_ptr[bx] = data_ptr[bx] indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] - with ib.for_range(0, ni, name="i") as i: - index = indices_ptr[i] - with ib.if_scope(index < 0): - out_ptr[index + n] = updates_ptr[i] - with ib.else_scope(): - out_ptr[index] = updates_ptr[i] + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", 1) + with ib.for_range(0, ni, name="i") as i: + index = indices_ptr[i] + with ib.if_scope(index < 0): + out_ptr[index + n] = updates_ptr[i] + with ib.else_scope(): + out_ptr[index] = updates_ptr[i] return ib.get() @@ -98,41 +105,56 @@ def gen_ir_2d(data, indices, updates, axis, out): ret : tir The computational ir. 
""" + warp_size = tvm.target.Target.current(False).thread_warp_size + n = data.shape[0] c = data.shape[1] ib = tvm.tir.ir_builder.create() - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", 1) out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - with ib.for_range(0, n, name="i") as i: - with ib.for_range(0, c, name="j") as j: - out_ptr[i * c + j] = data_ptr[i * c + j] + + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", n) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + i = bx + with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_: + j = j_ * warp_size + tx + with ib.if_scope(j < c): + idx = bx * c + j + out_ptr[idx] = data_ptr[idx] indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] ci = indices.shape[1] if axis == 0: - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ci, name="j") as j: - index = indices_ptr[i * ci + j] + with ib.new_scope(): + j = te.thread_axis("blockIdx.x") + ib.scope_attr(j, "thread_extent", ci) + with ib.for_range(0, ni, name="i") as i: + idx = i * ci + j + index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[(index + n) * c + j] = updates_ptr[i * ci + j] + out_ptr[(index + n) * c + j] = updates_ptr[idx] with ib.else_scope(): - out_ptr[index * c + j] = updates_ptr[i * ci + j] + out_ptr[index * c + j] = updates_ptr[idx] else: - with ib.for_range(0, ni, name="i") as i: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) with ib.for_range(0, ci, name="j") as j: - index = indices_ptr[i * ci + j] + idx = i * ci + j + index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[i * c + (index + c)] = updates_ptr[i * ci + j] + out_ptr[i * c + (index + c)] = updates_ptr[idx] with ib.else_scope(): - out_ptr[i * c + index] = updates_ptr[i * ci + j] - + out_ptr[i * c + index] = updates_ptr[idx] return ib.get() @@ 
-161,20 +183,29 @@ def gen_ir_3d(data, indices, updates, axis, out): ret : tir The computational ir. """ + warp_size = tvm.target.Target.current(False).thread_warp_size + n = data.shape[0] c = data.shape[1] h = data.shape[2] ib = tvm.tir.ir_builder.create() - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", 1) out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - with ib.for_range(0, n, name="i") as i: - with ib.for_range(0, c, name="j") as j: - with ib.for_range(0, h, name="k") as k: - out_ptr[(i * c + j) * h + k] = data_ptr[(i * c + j) * h + k] + + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", n) + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", c) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_: + k = k_ * warp_size + tx + with ib.if_scope(k < h): + idx = (bx * c + by) * h + k + out_ptr[idx] = data_ptr[idx] indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) @@ -183,33 +214,50 @@ def gen_ir_3d(data, indices, updates, axis, out): hi = indices.shape[2] if axis == 0: - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, hi, name="k") as k: - index = indices_ptr[(i * ci + j) * hi + k] - with ib.if_scope(index < 0): - out_ptr[((index + n) * c + j) * h + k] = updates_ptr[(i * ci + j) * hi + k] - with ib.else_scope(): - out_ptr[(index * c + j) * h + k] = updates_ptr[(i * ci + j) * hi + k] + with ib.new_scope(): + j = te.thread_axis("blockIdx.x") + ib.scope_attr(j, "thread_extent", ci) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_: + k = k_ * warp_size + tx + with ib.if_scope(k < hi): + idx = (i * ci + j) * hi + k + index = indices_ptr[idx] + with 
ib.if_scope(index < 0): + out_ptr[((index + n) * c + j) * h + k] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[(index * c + j) * h + k] = updates_ptr[idx] elif axis == 1: - with ib.for_range(0, ni, name="i") as i: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, hi, name="k") as k: - index = indices_ptr[(i * ci + j) * hi + k] - with ib.if_scope(index < 0): - out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[(i * ci + j) * hi + k] - with ib.else_scope(): - out_ptr[(i * c + index) * h + k] = updates_ptr[(i * ci + j) * hi + k] + with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_: + k = k_ * warp_size + tx + with ib.if_scope(k < hi): + idx = (i * ci + j) * hi + k + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[(i * c + index) * h + k] = updates_ptr[idx] else: - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, hi, name="k") as k: - index = indices_ptr[(i * ci + j) * hi + k] - with ib.if_scope(index < 0): - out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[(i * ci + j) * hi + k] - with ib.else_scope(): - out_ptr[(i * c + j) * h + index] = updates_ptr[(i * ci + j) * hi + k] - + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + with ib.for_range(0, hi, name="k") as k: + idx = (i * ci + j) * hi + k + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[(i * c + j) * h + index] = updates_ptr[idx] return ib.get() @@ -238,22 +286,31 @@ def gen_ir_4d(data, indices, updates, axis, out): 
ret : tir The computational ir. """ + warp_size = tvm.target.Target.current(False).thread_warp_size + n = data.shape[0] c = data.shape[1] h = data.shape[2] w = data.shape[3] ib = tvm.tir.ir_builder.create() - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", 1) out_ptr = ib.buffer_ptr(out) data_ptr = ib.buffer_ptr(data) - with ib.for_range(0, n, name="i") as i: - with ib.for_range(0, c, name="j") as j: - with ib.for_range(0, h, name="k") as k: - with ib.for_range(0, w, name="l") as l: - out_ptr[((i * c + j) * h + k) * w + l] = data_ptr[((i * c + j) * h + k) * w + l] + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", n) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", c) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", h) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ceil_div(w, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < w): + idx = ((i * c + j) * h + k) * w + l + out_ptr[idx] = data_ptr[idx] indices_ptr = ib.buffer_ptr(indices) updates_ptr = ib.buffer_ptr(updates) @@ -263,61 +320,74 @@ def gen_ir_4d(data, indices, updates, axis, out): wi = indices.shape[3] if axis == 0: - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, hi, name="k") as k: - with ib.for_range(0, wi, name="l") as l: - index = indices_ptr[((i * ci + j) * hi + k) * wi + l] + with ib.new_scope(): + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", hi) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < wi): + idx = ((i * ci + j) * hi + k) * wi + l + index = 
indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] + out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[idx] with ib.else_scope(): - out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] + out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[idx] elif axis == 1: - with ib.for_range(0, ni, name="i") as i: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", hi) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, hi, name="k") as k: - with ib.for_range(0, wi, name="l") as l: - index = indices_ptr[((i * ci + j) * hi + k) * wi + l] + with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < wi): + idx = ((i * ci + j) * hi + k) * wi + l + index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] + out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[idx] with ib.else_scope(): - out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] + out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[idx] elif axis == 2: - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, hi, name="k") as k: - with ib.for_range(0, wi, name="l") as l: - index = indices_ptr[((i * ci + j) * hi + k) * wi + l] + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with 
ib.for_range(0, hi, name="k") as k: + with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < wi): + idx = ((i * ci + j) * hi + k) * wi + l + index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] + out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[idx] with ib.else_scope(): - out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] + out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[idx] else: - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, hi, name="k") as k: - with ib.for_range(0, wi, name="l") as l: - index = indices_ptr[((i * ci + j) * hi + k) * wi + l] - with ib.if_scope(index < 0): - out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] - with ib.else_scope(): - out_ptr[((i * c + j) * h + k) * w + index] = updates_ptr[ - ((i * ci + j) * hi + k) * wi + l - ] + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", hi) + with ib.for_range(0, wi, name="l") as l: + idx = ((i * ci + j) * hi + k) * wi + l + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[((i * c + j) * h + k) * w + index] = updates_ptr[idx] return ib.get() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 9477703edcd6..bcfcdc06cbf6 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -919,6 +919,7 @@ def verify_scatter(dshape, ishape, axis=0): verify_scatter((2, 3, 4), (1, 3, 4), 0) 
verify_scatter((2, 3, 4), (2, 1, 4), 1) verify_scatter((2, 3, 4), (2, 3, 1), 2) + verify_scatter((4, 2, 1), (1, 1, 1), 0) verify_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0) verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1) verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2) From 489d00c52d780553e82039f762fb3f6e3ef2decf Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 8 Oct 2020 09:31:02 -0600 Subject: [PATCH 3/4] add dynamic shape tests --- tests/python/relay/test_op_level3.py | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index bcfcdc06cbf6..3ea0777df8ca 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -911,6 +911,27 @@ def verify_scatter(dshape, ishape, axis=0): op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + def verify_dynamic_scatter(dshape, ishape, axis=0): + d = relay.var("d", relay.TensorType([relay.Any() for i in range(len(dshape))], "float32")) + i = relay.var("i", relay.TensorType([relay.Any() for i in range(len(ishape))], "int64")) + u = relay.var("u", relay.TensorType([relay.Any() for i in range(len(ishape))], "float32")) + z = relay.op.scatter(d, i, u, axis) + + func = relay.Function([d, i, u], z) + + data_np = np.random.uniform(size=dshape).astype("float32") + updates_np = np.random.uniform(size=ishape).astype("float32") + indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") + + ref_res = ref_scatter(data_np, indices_np, updates_np, axis) + + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data_np, indices_np, updates_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_scatter((10,), 
(10,), 0) verify_scatter((10, 5), (10, 5), -2) verify_scatter((10, 5), (10, 5), -1) @@ -925,6 +946,20 @@ def verify_scatter(dshape, ishape, axis=0): verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2) verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3) + verify_dynamic_scatter((10,), (10,), 0) + verify_dynamic_scatter((10, 5), (10, 5), -2) + verify_dynamic_scatter((10, 5), (10, 5), -1) + verify_dynamic_scatter((10, 5), (3, 5), 0) + verify_dynamic_scatter((12, 4), (7, 2), 1) + verify_dynamic_scatter((2, 3, 4), (1, 3, 4), 0) + verify_dynamic_scatter((2, 3, 4), (2, 1, 4), 1) + verify_dynamic_scatter((2, 3, 4), (2, 3, 1), 2) + verify_dynamic_scatter((4, 2, 1), (1, 1, 1), 0) + verify_dynamic_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0) + verify_dynamic_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1) + verify_dynamic_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2) + verify_dynamic_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3) + def test_scatter_add(): def ref_scatter_add(data, indices, updates, axis=0): From 03de673a35de918004924d80c19a6d55a45cecbf Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 19 Oct 2020 14:43:17 -0600 Subject: [PATCH 4/4] remove unused variable --- python/tvm/topi/cuda/scatter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index c5c596287e9d..6522d74d8bef 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -120,7 +120,6 @@ def gen_ir_2d(data, indices, updates, axis, out): ib.scope_attr(bx, "thread_extent", n) tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", warp_size) - i = bx with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_: j = j_ * warp_size + tx with ib.if_scope(j < c):