From 8b5cffae7bcd42589bb7b1e8fbb8a5996670a876 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 26 Apr 2021 15:22:15 -0600 Subject: [PATCH 1/6] passing topi tests --- python/tvm/topi/cuda/scatter.py | 48 ++++++------ python/tvm/topi/scatter.py | 77 ++++++++++--------- python/tvm/topi/x86/scatter.py | 49 ++++++------ tests/python/topi/python/test_topi_scatter.py | 70 +++++++++++------ 4 files changed, 139 insertions(+), 105 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index fd05904ba8e7..e54d5821341c 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -18,7 +18,6 @@ """Scatter operator """ import tvm from tvm import te, autotvm -from ..scatter import _verify_scatter_nd_inputs from ..generic import schedule_extern from .nms import atomic_add from .sort import stable_sort_by_key_thrust @@ -723,7 +722,7 @@ def update_func(dst_ptr, dst_index, update): return out -def scatter_nd(data, indices, shape): +def scatter_nd(data, indices, updates, mode): """Scatter elements from a n-dimension array. Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape @@ -756,28 +755,28 @@ def scatter_nd(data, indices, shape): ------- ret : tvm.te.Tensor """ - _verify_scatter_nd_inputs(data, indices, shape) - def gen_ir(data_ptr, indices_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the data dimensions. + # number of loops. We do the same thing for all the update dimensions. fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i - fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: - fused_data_dimension *= i + fused_updates_dimension = 1 + for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: + fused_updates_dimension *= i fused_shape = 1 - for i in shape: + for i in data_ptr.shape: fused_shape *= i # For now we avoid parallizing over dimensions indexed by `indices` as @@ -789,38 +788,41 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - tdim = min(max_threads, fused_data_dimension) + tdim = min(max_threads, fused_updates_dimension) ib.scope_attr(tx, "thread_extent", tdim) - bdim = ceil_div(fused_data_dimension, tdim) + bdim = ceil_div(fused_updates_dimension, tdim) ib.scope_attr(bx, "thread_extent", bdim) - # zero data - # TODO(tkonolige): could we use topi.full to zero it instead? with ib.for_range(0, ceil_div(fused_shape, bdim)) as i: - index = i * fused_data_dimension + bx * tdim + tx + index = i * fused_updates_dimension + bx * tdim + tx with ib.if_scope(index < fused_shape): - out[index] = tvm.tir.Cast(data_ptr.dtype, 0) + out[index] = data[index] with ib.for_range(0, fused_indices_dimension) as i: j = bx * tdim + tx - with ib.if_scope(j < fused_data_dimension): - offset = fused_data_dimension + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. 
y_{K-1}] part # of the index into out. for l in reversed(range(indices_ptr.shape[0].value)): # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] - offset *= shape[l] - out[index] += data[i * fused_data_dimension + j] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not supported:", mode) return ib.get() - out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf") + out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( - [shape], - [data, indices], - lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], name="scatter_nd_cuda", diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index a376963aa55a..5d3bcc031e25 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -199,30 +199,31 @@ def scatter(data, indices, updates, axis=0): raise ValueError("scatter only support for 1-4 dimensions") -def _verify_scatter_nd_inputs(data, indices, shape): - mdim = int(indices.shape[0]) - assert mdim <= len(shape), ( - f"The first dimension of the indices ({mdim}) must be less than or equal to " - f"the length of the shape of the output ({len(shape)})." - ) - for i in range(len(indices.shape) - 1): - assert indices.shape[i + 1] == data.shape[i], ( - f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " - f"data[{i}] ({data.shape[i]})." - ) - for i in range(mdim, len(shape)): - data_ind = i - mdim + len(indices.shape) - 1 - assert data.shape[data_ind] == shape[i], ( - f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension " - f"of out_shape[{i}] ({shape[i]})." - ) - - assert ( - "int" in indices.dtype - ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}." - - -def scatter_nd(data, indices, shape): +# TODO(mbrookhart): move to type rel +# def _verify_scatter_nd_inputs(data, indices, shape): +# mdim = int(indices.shape[0]) +# assert mdim <= len(shape), ( +# f"The first dimension of the indices ({mdim}) must be less than or equal to " +# f"the length of the shape of the output ({len(shape)})." +# ) +# for i in range(len(indices.shape) - 1): +# assert indices.shape[i + 1] == data.shape[i], ( +# f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " +# f"data[{i}] ({data.shape[i]})." +# ) +# for i in range(mdim, len(shape)): +# data_ind = i - mdim + len(indices.shape) - 1 +# assert data.shape[data_ind] == shape[i], ( +# f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension " +# f"of out_shape[{i}] ({shape[i]})." +# ) +# +# assert ( +# "int" in indices.dtype +# ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}." + + +def scatter_nd(data, indices, updates, mode): """Scatter elements from a n-dimension array. Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape @@ -248,29 +249,30 @@ def scatter_nd(data, indices, shape): indices : tvm.te.Tensor The indices of the values to extract. - shape : Sequence[int] - The output shape. This must be specified because it cannot be inferred. 
+ updates : tvm.te.Tensor + The updates to apply at the Indices + + mode : string + The update mode for the algorith, either "update" or "add" Returns ------- ret : tvm.te.Tensor """ - _verify_scatter_nd_inputs(data, indices, shape) - def gen_ir(data_ptr, indices_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - # zero data - # TODO(tkonolige): could we use topi.full to zero it instead? fused_shape = 1 - for i in shape: + for i in data.shape: fused_shape *= i with ib.for_range(0, fused_shape) as i: - out[i] = Cast(data_ptr.dtype, 0) + out[i] = data[i] # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary @@ -300,15 +302,20 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): ) ) offset *= shape[l] - out[index] += data[i * fused_data_dimension + j] + if mode == "add": + out[index] += updates[i * fused_data_dimension + j] + elif mode == "update": + out[index] = updates[i * fused_data_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not supported:", mode) return ib.get() out_buf = decl_buffer(shape, data.dtype, "out_buf") return extern( [shape], - [data, indices], - lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], name="scatter_nd_generic", diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py index 8bb3f57e82e4..cdcb01e1c985 100644 --- a/python/tvm/topi/x86/scatter.py +++ b/python/tvm/topi/x86/scatter.py @@ -17,10 +17,9 @@ """Scatter operators for x86""" import tvm from tvm import te -from ..scatter import _verify_scatter_nd_inputs -def scatter_nd(data, indices, shape): +def scatter_nd(data, indices, updates, mode): """Scatter elements from a n-dimension array. Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape @@ -46,62 +45,68 @@ def scatter_nd(data, indices, shape): indices : tvm.te.Tensor The indices of the values to extract. - shape : Sequence[int] - The output shape. This must be specified because it cannot be inferred. + updates : tvm.te.Tensor + The updates to apply at the Indices + + mode : string + The update mode for the algorith, either "update" or "add" Returns ------- ret : tvm.te.Tensor """ - _verify_scatter_nd_inputs(data, indices, shape) - def gen_ir(data_ptr, indices_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # pylint: disable=invalid-name ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the data dimensions. + # number of loops. We do the same thing for all the update dimensions. 
fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i - fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: - fused_data_dimension *= i + fused_updates_dimension = 1 + for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: + fused_updates_dimension *= i fused_shape = 1 - for i in shape: + for i in data_ptr.shape: fused_shape *= i - # zero data - # TODO(tkonolige): could we use topi.full to zero it instead? with ib.for_range(0, fused_shape) as i: - out[i] = tvm.tir.Cast(data_ptr.dtype, 0) + out[i] = data[i] with ib.for_range(0, fused_indices_dimension) as i: - with ib.for_range(0, fused_data_dimension, kind="parallel") as j: - offset = fused_data_dimension + with ib.for_range(0, fused_updates_dimension, kind="parallel") as j: + offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part # of the index into out. for l in reversed(range(indices_ptr.shape[0].value)): # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] - offset *= shape[l] - out[index] += data[i * fused_data_dimension + j] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not supported:", mode) return ib.get() - out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf") + out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( - [shape], - [data, indices], - lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], name="scatter_nd_x86", diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index ad73bb51f2d3..648ef62a04ee 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -23,44 +23,64 @@ @tvm.testing.parametrize_targets def test_scatter_nd(dev, target): - def check_scatter_nd(data, indices, shape, out): + def check_scatter_nd(data, indices, updates, out, mode="add"): implementations = { - "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern), - "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern), - "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern), + "generic": ( + lambda x, y, z: topi.scatter_nd(x, y, z, mode), + topi.generic.schedule_extern, + ), + "gpu": ( + lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode), + topi.generic.schedule_extern, + ), + "cpu": ( + lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode), + topi.generic.schedule_extern, + ), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, dev, fcompute, fschedule) + tvm.topi.testing.compare_numpy_tvm( + [data, indices, updates], out, target, dev, fcompute, fschedule + ) - data = np.array([2, 3, 0]) + data = np.zeros((2, 2)).astype("int64") indices = np.array([[1, 1, 0], [0, 1, 0]]) - shape = (2, 2) + updates = np.array([2, 3, 0]) out = np.array([[0, 0], [2, 3]]) - check_scatter_nd(data, indices, shape, out) + check_scatter_nd(data, indices, updates, 
out)

-    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    data = np.zeros((2, 2, 2, 2)).astype("int64")
     indices = np.array([[0, 1], [1, 1]])
-    shape = (2, 2, 2, 2)
+    updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
     out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
-    check_scatter_nd(data, indices, shape, out)
+    check_scatter_nd(data, indices, updates, out)
 
-    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
     indices = np.array([[1, 0, 0]])
+    updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
     shape = (2, 1560)
-    out = np.zeros(shape).astype("float32")
-    out[1, :] += data[0, :]
-    out[0, :] += data[1, :]
-    out[0, :] += data[2, :]
-    check_scatter_nd(data, indices, shape, out)
+    data = np.zeros(shape).astype("float32")
+    out = data.copy()
+    out[1, :] += updates[0, :]
+    out[0, :] += updates[1, :]
+    out[0, :] += updates[2, :]
+    check_scatter_nd(data, indices, updates, out)
 
-    data = np.ones((5, 3)).astype("float64")
-    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
-    shape = (2, 7, 3)
-    out = np.zeros(shape).astype("float64")
-    for i in range(indices.shape[1]):
-        for j in range(data.shape[1]):
-            out[indices[0, i], indices[1, i], j] += data[i, j]
-    check_scatter_nd(data, indices, shape, out)
+    for mode in ["add", "update"]:
+        updates = np.ones((5, 3)).astype("float64")
+        indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
+            "int64"
+        )
+        shape = (2, 7, 3)
+        data = np.random.random(shape).astype("float64")
+        out = data.copy()
+        for i in range(indices.shape[1]):
+            for j in range(updates.shape[1]):
+                if mode == "add":
+                    out[indices[0, i], indices[1, i], j] += updates[i, j]
+                elif mode == "update":
+                    out[indices[0, i], indices[1, i], j] = updates[i, j]
+
+        check_scatter_nd(data, indices, updates, out, mode)
 
 
 if __name__ == "__main__":

From 5d47dcdcab275b807b59db5f8b447cb1f5c8c529 Mon Sep 17 00:00:00 2001
From: Matthew
Date: Mon, 26 Apr 2021 15:38:05 -0600
Subject: [PATCH 2/6] passing relay tests, needs better shape checking still

---
 include/tvm/relay/attrs/transform.h     |  5 +-
 python/tvm/relay/frontend/pytorch.py    | 21 ++----
 python/tvm/relay/op/_tensor_grad.py     |  2 +-
 python/tvm/relay/op/_transform.py       |  2 +-
 python/tvm/relay/op/strategy/generic.py |  2 +-
 python/tvm/relay/op/transform.py        | 13 ++--
 src/relay/op/tensor/transform.cc        | 58 +++++++++--------
 tests/python/relay/test_op_level3.py    | 85 ++++++++++++++-----------
 8 files changed, 95 insertions(+), 93 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index a5544c8a8799..113c8209fe6a 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -126,10 +126,11 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
 };
 
 struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
-  Array<Integer> out_shape;
+  String mode;
 
   TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
-    TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
+    TVM_ATTR_FIELD(mode).describe(
+        "Accumulation mode of the scatter, either \"update\" or \"add\".");
   }
 };
 
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a31c44a369f9..025942bcfa22 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2118,26 +2118,13 @@ def index_put(self, inputs, input_types):
         indices = inputs[1]
         values = inputs[2]
         accumulate = inputs[3]
-        # accumulate parameter 
is ignored. - # torch.index_put default is False but Relay.scatter_nd accumulates values. - # We assume there is no duplicate indices in torch.index_put input if not accumulate: - logging.warning( - "torch.index_put accumulate parameter is False. " - "TVM uses tvm.relay.scatter_nd operator which accumulates values. " - "Make sure there is no duplicate indices in torch.index_put input." - ) - # Relay scatter_nd does not support input tensor - # We assume that torch.index_put is used with empty zero-values input tensor - # scatter_nd will create empty zero-values tensor with a given shape - out_shape = self.infer_shape(in_tensor) - logging.warning( - "tvm.relay.scatter_nd operator does not support input tensor parameter. " - "TVM assumes that torch.index_put is used with empty zero-values input tensor" - ) + mode = "update" + else: + mode = "add" # Combine array of index tensors into one index tensor with shape (N,_) index_tensor = _op.stack(indices, axis=0) - return _op.transform.scatter_nd(values, index_tensor, out_shape) + return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode) def scalar_tensor(self, inputs, input_types): data = inputs[0] diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 5836aebce393..108bef0242fe 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -834,7 +834,7 @@ def gather_nd_grad(orig, grad): Returns the gradient of gather_nd, which is simply scatter_nd. """ data, indices = orig.args - return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)] + return [scatter_nd(zeros_like(data), indices, grad, mode="add"), zeros_like(indices)] @register_gradient("reshape_like") diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 16262833d1bf..25c6848fb7f8 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -145,7 +145,7 @@ def compute_scatter_add(attrs, inputs, output_type): @_reg.register_compute("scatter_nd") def compute_scatter_nd(attrs, inputs, output_type): """Compute definition of scatter_nd""" - return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)] + return [topi.scatter_nd(inputs[0], inputs[1], inputs[2], attrs.mode)] _reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 70e021910ab0..7451b397265f 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1288,7 +1288,7 @@ def wrap_compute_scatter_nd(topi_compute): """Wrap scatter_nd topi compute""" def _compute_scatter_nd(attrs, inputs, _): - return [topi_compute(inputs[0], inputs[1], attrs.out_shape)] + return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.mode)] return _compute_scatter_nd diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index f94a00db2fb1..df2686196151 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -310,8 +310,8 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) -def scatter_nd(data, indices, out_shape): - """Scatter values from an array. +def scatter_nd(data, indices, updates, mode="update"): + """Scatter values from an array and update. See :py:func:`tvm.topi.scatter` for how data is scattered. 
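The semantics behind this new four-argument signature can be summarized with a small NumPy model (a sketch for illustration only; scatter_nd_ref and its loop structure are not part of the patch, and real inputs go through the TOPI kernels above):

    import numpy as np

    def scatter_nd_ref(data, indices, updates, mode="update"):
        # The output starts as a copy of data, matching the new type relation:
        # the result shape is always data.shape.
        out = data.copy()
        mdim = indices.shape[0]
        # Iterate over the fused (Y_0, ..., Y_{K-1}) dimensions.
        for y in np.ndindex(*indices.shape[1:]):
            # indices[:, y] selects the first M output coordinates.
            loc = tuple(indices[(m,) + y] for m in range(mdim))
            if mode == "update":
                out[loc] = updates[y]
            elif mode == "add":
                out[loc] += updates[y]
            else:
                raise NotImplementedError("scatter_nd mode not supported: " + mode)
        return out

    # Mirrors the first topi test: writes land at (1, 0), (1, 1), and (0, 0).
    data = np.zeros((2, 2), dtype="int64")
    indices = np.array([[1, 1, 0], [0, 1, 0]])
    updates = np.array([2, 3, 0])
    assert (scatter_nd_ref(data, indices, updates) == [[0, 0], [2, 3]]).all()

Because the loop applies one write per index tuple, mode="add" accumulates one contribution per duplicate location, which is the behavior the topi and relay tests below rely on.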
@@ -323,15 +323,18 @@ def scatter_add(data, indices, updates, axis):
     indices : relay.Expr
         The index locations to update.
 
-    out_shape : Union[Tuple[int], List[int]]
-        Output shape of the scatter.
+    updates : relay.Expr
+        The values to update.
+
+    mode : string
+        The accumulation mode for scatter. "update" or "add"
 
     Returns
     -------
     ret : relay.Expr
         The computed result.
     """
-    return _make.scatter_nd(data, indices, out_shape)
+    return _make.scatter_nd(data, indices, updates, mode)
 
 
 def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 11e94cb4b93e..d010d93eb150 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1096,8 +1096,8 @@ TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);
 
 bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
-  // `types` contains: [data, indices, result]
-  ICHECK_EQ(types.size(), 3);
+  // `types` contains: [data, indices, updates, result]
+  ICHECK_EQ(types.size(), 4);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* indices = types[1].as<TensorTypeNode>();
   if (data == nullptr) {
@@ -1111,37 +1111,38 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
   ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers.";
-  const auto out_shape = attrs.as<ScatterNDAttrs>()->out_shape;
-  const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
-  const size_t kdim = indices->shape.size() - 1;
-  const size_t ndim = out_shape.size();
-  ICHECK_LE(size_t(mdim->value), ndim)
-      << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
-         "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
-  // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
-  for (size_t i = 0; i < kdim; i++) {
-    reporter->AssertEQ(indices->shape[i + 1], data->shape[i]);
-  }
-
-  std::vector<IndexExpr> oshape;
-  for (auto& x : out_shape) {
-    oshape.push_back(x);
-  }
+  // TODO(mbrookhart) rethink this logic?
+  // const auto out_shape = attrs.as<ScatterNDAttrs>()->out_shape;
+  // const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
+  // const size_t kdim = indices->shape.size() - 1;
+  // const size_t ndim = out_shape.size();
+  // ICHECK_LE(size_t(mdim->value), ndim)
+  //     << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
+  //        "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
+  // // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
+  // for (size_t i = 0; i < kdim; i++) {
+  //   reporter->AssertEQ(indices->shape[i + 1], data->shape[i]);
+  // }
+
+  // std::vector<IndexExpr> oshape;
+  // for (auto& x : out_shape) {
+  //   oshape.push_back(x);
+  // }
+
+  // // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
+  // for (size_t i = mdim->value; i < ndim; i++) {
+  //   reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
+  // }
 
-  // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
-  for (size_t i = mdim->value; i < ndim; i++) {
-    reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
-  }
-
-  reporter->Assign(types[2], TensorType(oshape, data->dtype));
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
   return true;
 }
 
-Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {
+Expr MakeScatterND(Expr data, Expr indices, Expr updates, String mode) {
   auto attrs = make_object<ScatterNDAttrs>();
-  attrs->out_shape = out_shape;
+  attrs->mode = std::move(mode);
   static const Op& op = Op::Get("scatter_nd");
-  return Call(op, {data, indices}, Attrs(attrs), {});
+  return Call(op, {data, indices, updates}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);
@@ -1156,9 +1157,10 @@ whose shape is defined by indices.
 Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape
 (M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}).
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(2)
+    .set_num_inputs(3)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("indices", "Tensor", "The indices tensor.")
+    .add_argument("updates", "Tensor", "The updates tensor.")
     .set_support_level(3)
     .add_type_rel("ScatterND", ScatterNDRel)
     .set_attr<TOpPattern>("TOpPattern", kOpaque);
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index bf0a7e4952e5..e84b22b30ce1 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1833,38 +1833,39 @@ def test_cumprod(target, dev):
 
 @tvm.testing.parametrize_targets
 def test_scatter_nd(target, dev):
-    def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+    def verify_scatter_nd(
+        data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
+    ):
         data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
         indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
+        updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
 
-        out = relay.op.scatter_nd(data, indices, shape)
-        func = relay.Function([data, indices], out)
+        out = relay.op.scatter_nd(data, indices, updates, mode)
+        func = relay.Function([data, indices, updates], out)
 
         for kind in ["graph", "debug"]:
             intrp = relay.create_executor(kind, device=dev, target=target)
-            op_res = intrp.evaluate(func)(data_np, indices_np)
+            op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
 
-    def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+    def verify_scatter_nd_with_stack(
+        data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
+    ):
         data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
         indices_vars = [
             relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
         ]
+        updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))
 
         # test if scatter_nd works in case indices are prepared by another Relay operator
         indices = relay.op.stack(indices_vars, axis=0)
-        out = relay.op.scatter_nd(data, indices, shape)
+        out = relay.op.scatter_nd(data, indices, updates, mode)
         func = relay.Function(
-            [
-                data,
-            ]
-            + indices_vars,
+            [data, updates] + indices_vars,
             out,
         )
-        fargs = [
-            data_np,
-        ]
+        fargs = [data_np, updates_np]
         for a in indices_np:
             fargs.append(a)
         for kind in ["graph", "debug"]:
@@ 
-1872,39 +1873,47 @@ def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, op_res = intrp.evaluate(func)(*fargs) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol) - data = np.array([2, 3, 0]) + data = np.zeros((2, 2)).astype("int64") indices = np.array([[1, 1, 0], [0, 1, 0]]) - shape = (2, 2) + updates = np.array([2, 3, 0]) out = np.array([[0, 0], [2, 3]]) - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) + verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd_with_stack(data, indices, updates, out) - data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + data = np.zeros((2, 2, 2, 2)).astype("int64") indices = np.array([[0, 1], [1, 1]]) - shape = (2, 2, 2, 2) + updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]]) - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) + verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd_with_stack(data, indices, updates, out) - data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") indices = np.array([[1, 0, 0]]) + updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") shape = (2, 1560) - out = np.zeros(shape).astype("float32") - out[1, :] += data[0, :] - out[0, :] += data[1, :] - out[0, :] += data[2, :] - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) - - data = np.ones((5, 3)).astype("float64") - indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64") - shape = (2, 7, 3) - out = np.zeros(shape).astype("float64") - for i in range(indices.shape[1]): - for j in range(data.shape[1]): - out[indices[0, i], indices[1, i], j] += data[i, j] - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) + data = np.zeros(shape).astype("float32") + out = data.copy() + out[1, :] += updates[0, :] + out[0, :] += updates[1, :] + out[0, :] += updates[2, :] + verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd_with_stack(data, indices, updates, out) + + for mode in ["add", "update"]: + indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype( + "int64" + ) + updates = np.ones((5, 3)).astype("float64") + shape = (2, 7, 3) + data = np.random.random(shape).astype("float64") + out = data.copy() + for i in range(indices.shape[1]): + for j in range(updates.shape[1]): + if mode == "add": + out[indices[0, i], indices[1, i], j] += updates[i, j] + elif mode == "update": + out[indices[0, i], indices[1, i], j] = updates[i, j] + verify_scatter_nd(data, indices, updates, out, mode) + verify_scatter_nd_with_stack(data, indices, updates, out, mode) def test_unique(): From 459e88cd300d87713d79a41f5d0bc6867ab925d6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 26 Apr 2021 16:03:47 -0600 Subject: [PATCH 3/6] support ONNX operator --- python/tvm/relay/frontend/onnx.py | 11 +++++++++++ tests/python/frontend/onnx/test_forward.py | 1 - 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cc66cd3c6fe8..a06208429ed3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1329,6 +1329,16 @@ def _impl_v1(cls, inputs, attr, params): return _op.scatter(inputs[0], inputs[1], inputs[2], 
axis) +class ScatterND(OnnxOpConverter): + """Operator converter for Scatter.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + indices_dim = len(infer_shape(inputs[1])) + axes = list(range(indices_dim)) + return _op.scatter_nd(inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update") + + class Greater(OnnxOpConverter): """Operator logical greater.""" @@ -2820,6 +2830,7 @@ def _get_convert_map(opset): "Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}), "Scatter": Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), + "ScatterND": ScatterND.get_converter(opset), "Squeeze": AttrCvt("squeeze", {"axes": "axis"}), "Unsqueeze": Unsqueeze.get_converter(opset), "Pad": Pad.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 595a3b1c89b3..d9ac7413bb2e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4209,7 +4209,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_round/", "test_scan9_sum/", "test_scan_sum/", - "test_scatternd/", "test_simple_rnn_defaults/", "test_simple_rnn_with_initial_bias/", "test_strnormalizer_export_monday_casesensintive_lower/", From 505c5dbb2bc81b77ac9f099bf80a5ab4532847a5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 26 Apr 2021 16:16:05 -0600 Subject: [PATCH 4/6] add shape checking back in --- python/tvm/topi/cuda/scatter.py | 2 ++ python/tvm/topi/scatter.py | 44 ++++++++++++++-------------- python/tvm/topi/x86/scatter.py | 2 ++ src/relay/op/tensor/transform.cc | 50 ++++++++++++++++++-------------- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index e54d5821341c..d037dce5028a 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -18,6 +18,7 @@ """Scatter operator """ import tvm from tvm import te, autotvm +from ..scatter import _verify_scatter_nd_inputs from ..generic import schedule_extern from .nms import atomic_add from .sort import stable_sort_by_key_thrust @@ -755,6 +756,7 @@ def scatter_nd(data, indices, updates, mode): ------- ret : tvm.te.Tensor """ + _verify_scatter_nd_inputs(data, indices, updates) def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = tvm.tir.ir_builder.create() diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 5d3bcc031e25..96819bdd2eef 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -199,28 +199,27 @@ def scatter(data, indices, updates, axis=0): raise ValueError("scatter only support for 1-4 dimensions") -# TODO(mbrookhart): move to type rel -# def _verify_scatter_nd_inputs(data, indices, shape): -# mdim = int(indices.shape[0]) -# assert mdim <= len(shape), ( -# f"The first dimension of the indices ({mdim}) must be less than or equal to " -# f"the length of the shape of the output ({len(shape)})." -# ) -# for i in range(len(indices.shape) - 1): -# assert indices.shape[i + 1] == data.shape[i], ( -# f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " -# f"data[{i}] ({data.shape[i]})." -# ) -# for i in range(mdim, len(shape)): -# data_ind = i - mdim + len(indices.shape) - 1 -# assert data.shape[data_ind] == shape[i], ( -# f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension " -# f"of out_shape[{i}] ({shape[i]})." 
-#     )
-#
-#     assert (
-#         "int" in indices.dtype
-#     ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}."
+def _verify_scatter_nd_inputs(data, indices, updates):
+    mdim = int(indices.shape[0])
+    assert mdim <= len(data.shape), (
+        f"The first dimension of the indices ({mdim}) must be less than or equal to "
+        f"the length of the shape of the output ({len(data.shape)})."
+    )
+    for i in range(len(indices.shape) - 1):
+        assert indices.shape[i + 1] == updates.shape[i], (
+            f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
+            f"updates[{i}] ({updates.shape[i]})."
+        )
+    for i in range(mdim, len(data.shape)):
+        data_ind = i - mdim + len(indices.shape) - 1
+        assert updates.shape[data_ind] == data.shape[i], (
+            f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
+            f"of out_shape[{i}] ({data.shape[i]})."
+        )
+
+    assert (
+        "int" in indices.dtype
+    ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}."
 
 
 def scatter_nd(data, indices, updates, mode):
@@ -259,6 +258,7 @@ def scatter_nd(data, indices, updates, mode):
     -------
     ret : tvm.te.Tensor
     """
+    _verify_scatter_nd_inputs(data, indices, updates)
 
     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         ib = ir_builder.create()
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
index cdcb01e1c985..4e0af4c9dd2c 100644
--- a/python/tvm/topi/x86/scatter.py
+++ b/python/tvm/topi/x86/scatter.py
@@ -17,6 +17,7 @@
 """Scatter operators for x86"""
 import tvm
 from tvm import te
+from ..scatter import _verify_scatter_nd_inputs
 
 
 def scatter_nd(data, indices, updates, mode):
@@ -55,6 +56,7 @@ def scatter_nd(data, indices, updates, mode):
     -------
     ret : tvm.te.Tensor
     """
+    _verify_scatter_nd_inputs(data, indices, updates)
 
     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # pylint: disable=invalid-name
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index d010d93eb150..e937cb0c7b1f 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1100,6 +1100,7 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK_EQ(types.size(), 4);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* indices = types[1].as<TensorTypeNode>();
+  const auto* updates = types[2].as<TensorTypeNode>();
   if (data == nullptr) {
     ICHECK(types[0].as<TensorTypeNode>())
         << "ScatterND: expect input data type to be TensorType but got " << types[0];
@@ -1110,29 +1111,34 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         << "ScatterND: expect indices type to be TensorType but got " << types[1];
     return false;
   }
+  if (updates == nullptr) {
+    ICHECK(types[2].as<TensorTypeNode>())
+        << "ScatterND: expect updates type to be TensorType but got " << types[2];
+    return false;
+  }
   ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers.";
-  // TODO(mbrookhart) rethink this logic?
-  // const auto out_shape = attrs.as<ScatterNDAttrs>()->out_shape;
-  // const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
-  // const size_t kdim = indices->shape.size() - 1;
-  // const size_t ndim = out_shape.size();
-  // ICHECK_LE(size_t(mdim->value), ndim)
-  //     << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
-  //        "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
-  // // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
-  // for (size_t i = 0; i < kdim; i++) {
-  //   reporter->AssertEQ(indices->shape[i + 1], data->shape[i]);
-  // }
-
-  // std::vector<IndexExpr> oshape;
-  // for (auto& x : out_shape) {
-  //   oshape.push_back(x);
-  // }
-
-  // // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
-  // for (size_t i = mdim->value; i < ndim; i++) {
-  //   reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
-  // }
+
+  const auto out_shape = data->shape;
+  const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
+  const size_t kdim = indices->shape.size() - 1;
+  const size_t ndim = out_shape.size();
+  ICHECK_LE(size_t(mdim->value), ndim)
+      << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
+         "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
+  // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
+  for (size_t i = 0; i < kdim; i++) {
+    reporter->AssertEQ(indices->shape[i + 1], updates->shape[i]);
+  }
+
+  std::vector<IndexExpr> oshape;
+  for (auto& x : out_shape) {
+    oshape.push_back(x);
+  }
+
+  // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
+  for (size_t i = mdim->value; i < ndim; i++) {
+    reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
+  }
 
   reporter->Assign(types[3], TensorType(data->shape, data->dtype));
   return true;

From 9fc43c1fe59cf6f9e33ae9108006e84cf52eb98e Mon Sep 17 00:00:00 2001
From: Matthew
Date: Mon, 26 Apr 2021 16:20:55 -0600
Subject: [PATCH 5/6] fix lint

---
 python/tvm/relay/frontend/onnx.py | 4 +++-
 python/tvm/topi/scatter.py        | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a06208429ed3..05241a677e59 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1336,7 +1336,9 @@ class ScatterND(OnnxOpConverter):
     def _impl_v11(cls, inputs, attr, params):
         indices_dim = len(infer_shape(inputs[1]))
         axes = list(range(indices_dim))
-        return _op.scatter_nd(inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update")
+        return _op.scatter_nd(
+            inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update"
+        )
 
 
 class Greater(OnnxOpConverter):
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index 96819bdd2eef..959c535412c8 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate
+from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate
 from ..te import extern, hybrid

From f9954bac2f9c4846a3be072a16078a1cbac8c97a Mon Sep 17 00:00:00 2001
From: Matthew
Date: Mon, 26 Apr 2021 17:17:51 -0600
Subject: [PATCH 6/6] update docstring

---
 python/tvm/topi/cuda/scatter.py | 20 +++++++++++++-------
 python/tvm/topi/scatter.py      | 15 +++++++++------
 python/tvm/topi/x86/scatter.py  | 15 +++++++++------
 3 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index d037dce5028a..cee13d7e01a2 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -726,8 +726,9 @@ def update_func(dst_ptr, dst_index, update):
 def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
-    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
-    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+    Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+    scatter_nd computes
 
     .. code-block::
 
@@ -737,9 +738,9 @@ def scatter_nd(data, indices, updates, mode):
             x_M,
             ...,
             x_{N-1}
-        ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+        ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
 
-    all other entries in the output are 0. Repeated indices are summed.
+    where the update function f is determined by the mode.
 
     Parameters
     ----------
@@ -749,8 +750,13 @@ def scatter_nd(data, indices, updates, mode):
     indices : tvm.te.Tensor
         The indices of the values to extract.
 
-    shape : Sequence[int]
-        The output shape. This must be specified because it cannot be inferred.
+    updates : tvm.te.Tensor
+        The updates to apply at the Indices
+
+    mode : string
+        The update mode for the algorithm, either "update" or "add"
+        If update, the update values will replace the input data
+        If add, the update values will be added to the input data
 
     Returns
     -------
@@ -816,7 +822,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
             elif mode == "add":
                 out[index] += updates[i * fused_updates_dimension + j]
             else:
-                raise NotImplementedError("scatter_nd mode not supported:", mode)
+                raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
     return ib.get()
 
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index 959c535412c8..d7b008c4c33f 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -225,8 +225,9 @@ def _verify_scatter_nd_inputs(data, indices, updates):
 def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
-    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
-    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+    Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+    scatter_nd computes
 
     .. code-block::
 
@@ -236,9 +237,9 @@ def scatter_nd(data, indices, updates, mode):
             x_M,
             ...,
             x_{N-1}
-        ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+        ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
 
-    all other entries in the output are 0. Repeated indices are summed.
+    where the update function f is determined by the mode.
 
     Parameters
     ----------
@@ -252,7 +253,9 @@ def scatter_nd(data, indices, updates, mode):
         The updates to apply at the Indices
 
     mode : string
-        The update mode for the algorith, either "update" or "add"
+        The update mode for the algorithm, either "update" or "add"
+        If update, the update values will replace the input data
+        If add, the update values will be added to the input data
 
     Returns
     -------
@@ -307,7 +310,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
             elif mode == "update":
                 out[index] = updates[i * fused_data_dimension + j]
             else:
-                raise NotImplementedError("scatter_nd mode not supported:", mode)
+                raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
     return ib.get()
 
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
index 4e0af4c9dd2c..5eb5e6e99b6c 100644
--- a/python/tvm/topi/x86/scatter.py
+++ b/python/tvm/topi/x86/scatter.py
@@ -23,8 +23,9 @@ def scatter_nd(data, indices, updates, mode):
     """Scatter elements from a n-dimension array.
 
-    Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
-    (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
+    Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
+    (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
+    scatter_nd computes
 
     .. code-block::
 
@@ -34,9 +35,9 @@ def scatter_nd(data, indices, updates, mode):
             x_M,
             ...,
             x_{N-1}
-        ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
+        ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
 
-    all other entries in the output are 0. Repeated indices are summed.
+    where the update function f is determined by the mode.
 
     Parameters
     ----------
@@ -50,7 +51,9 @@ def scatter_nd(data, indices, updates, mode):
         The updates to apply at the Indices
 
     mode : string
-        The update mode for the algorith, either "update" or "add"
+        The update mode for the algorithm, either "update" or "add"
+        If update, the update values will replace the input data
+        If add, the update values will be added to the input data
 
     Returns
     -------
@@ -100,7 +103,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
             elif mode == "add":
                 out[index] += updates[i * fused_updates_dimension + j]
             else:
-                raise NotImplementedError("scatter_nd mode not supported:", mode)
+                raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
 
     return ib.get()
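With all six patches applied, the operator can be exercised end to end the same way the relay tests above do; a minimal standalone sketch (the "llvm" target and graph executor are illustrative choices here, not something the patches prescribe):

    import numpy as np
    import tvm
    from tvm import relay

    data_np = np.zeros((2, 2)).astype("int64")
    indices_np = np.array([[1, 1, 0], [0, 1, 0]]).astype("int64")
    updates_np = np.array([2, 3, 0]).astype("int64")

    data = relay.var("data", shape=data_np.shape, dtype="int64")
    indices = relay.var("indices", shape=indices_np.shape, dtype="int64")
    updates = relay.var("updates", shape=updates_np.shape, dtype="int64")

    # mode="update" overwrites the destination; mode="add" accumulates into it.
    out = relay.op.scatter_nd(data, indices, updates, "update")
    func = relay.Function([data, indices, updates], out)

    intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm")
    res = intrp.evaluate(func)(data_np, indices_np, updates_np)
    print(res)  # [[0 0] [2 3]]

Because the output is now copied from data rather than zero-initialized, the same call with a nonzero data tensor preserves every element that the indices do not touch.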