From 109bd6128d37c9f727ac3b73bef40e5508e4d7e2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 09:20:03 +0300 Subject: [PATCH 01/22] remove scatter attr class --- include/tvm/relay/attrs/transform.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 7680883248e0..5667d0e3cb3d 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -148,14 +148,6 @@ struct ReshapeLikeAttrs : public tvm::AttrsNode { } }; // struct ReshapeLikeAttrs -struct ScatterAttrs : public tvm::AttrsNode { - Integer axis; - - TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") { - TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); - } -}; - struct ScatterElementsAttrs : public tvm::AttrsNode { Integer axis; String reduction; From 36bb4a88c6542f96ddcfcbb28c61e4cd36fda878 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 09:29:32 +0300 Subject: [PATCH 02/22] update pytorch: scatter was replaced by scatter_elements --- python/tvm/relay/frontend/pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 635cb960a829..62851506fc89 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -496,7 +496,7 @@ def slice(self, inputs, input_types): end[dim] = target_end else: target_end = _expr.const(target_end) - end = _op.scatter( + end = _op.scatter_elements( end, _op.expand_dims(_expr.const(dim), axis=0), _op.expand_dims(target_end, axis=0), @@ -508,7 +508,7 @@ def slice(self, inputs, input_types): ttype = self.infer_type(target_end).dtype if str(ttype) != axis_dtype: target_end = _op.cast(target_end, axis_dtype) - end = _op.scatter( + end = _op.scatter_elements( end, _op.expand_dims(_expr.const(dim), axis=0), _op.expand_dims(target_end, axis=0), @@ -2556,7 +2556,7 @@ def scatter(self, inputs, input_types): axis = int(inputs[1]) index = inputs[2] src = inputs[3] - return _op.transform.scatter(data, index, src, axis) + return _op.scatter_elements(data, index, src, axis) def index_put(self, inputs, input_types): in_tensor = inputs[0] @@ -2569,7 +2569,7 @@ def index_put(self, inputs, input_types): mode = "add" # Combine array of index tensors into one index tensor with shape (N,_) index_tensor = _op.stack(indices, axis=0) - return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode) + return _op.scatter_nd(in_tensor, index_tensor, values, mode) def scalar_tensor(self, inputs, input_types): data = inputs[0] From cee50173714b8aad3dda5b1d170e74f21a231d2c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 09:39:40 +0300 Subject: [PATCH 03/22] remove scatter compute and strategy registration --- python/tvm/relay/op/_transform.py | 10 ---------- python/tvm/relay/op/strategy/generic.py | 21 --------------------- python/tvm/topi/generic/search.py | 16 ---------------- 3 files changed, 47 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index f28c28ce62a6..da93a71ce0cb 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -104,15 +104,6 @@ def compute_strided_set(attrs, inputs, output_type): # argwhere _reg.register_strategy("argwhere", strategy.argwhere_strategy) -# scatter -@_reg.register_compute("scatter") -def compute_scatter(attrs, inputs, output_type): - """Compute definition of 
scatter""" - return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)] - - -_reg.register_strategy("scatter", strategy.scatter_strategy) - # sparse_fill_empty_rows @_reg.register_compute("sparse_fill_empty_rows") def compute_sparse_fill_empty_rows(attrs, inputs, output_type): @@ -677,7 +668,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims): return ValueError("Does not support rank higher than 5 in argwhere") -_reg.register_shape_func("scatter", False, elemwise_shape_func) _reg.register_shape_func("scatter_elements", False, elemwise_shape_func) _reg.register_shape_func("scatter_nd", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 4641fb18f7ba..2b12861e040e 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1548,27 +1548,6 @@ def proposal_strategy(attrs, inputs, out_type, target): return strategy -# scatter -@override_native_generic_func("scatter_strategy") -def scatter_strategy(attrs, outs, out_type, target): - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_scatter(topi.scatter), - wrap_topi_schedule(topi.generic.schedule_scatter), - name="scatter.generic", - ) - return strategy - - -def wrap_compute_scatter(topi_compute): - """Wrap scatter topi compute""" - - def _compute_scatter(attrs, inputs, _): - return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)] - - return _compute_scatter - - # scatter_elements @override_native_generic_func("scatter_elements_strategy") def scatter_elements_strategy(attrs, inputs, out_type, target): diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py index 826194e75c2a..9a80e678c212 100644 --- a/python/tvm/topi/generic/search.py +++ b/python/tvm/topi/generic/search.py @@ -36,22 +36,6 @@ def schedule_argwhere(outs): return _default_schedule(outs, False) -def schedule_scatter(outs): - """Schedule for scatter operator. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of scatter. - - Returns - ------- - s: Schedule - The computation schedule for the op. 
- """ - return _default_schedule(outs, False) - - def schedule_sparse_fill_empty_rows(outs): return _default_schedule(outs, False) From d6c156a84794db0583b26a05adb43d620d9159b3 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 09:40:51 +0300 Subject: [PATCH 04/22] remove scatter attrs registration --- python/tvm/relay/op/op_attrs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 0214ae8a46b6..4e9a9a4707a1 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -529,11 +529,6 @@ class RequantizeAttrs(Attrs): """Attributes used in requantize operators""" -@tvm._ffi.register_object("relay.attrs.ScatterAttrs") -class ScatterAttrs(Attrs): - """Attributes used in scatter operators""" - - @tvm._ffi.register_object("relay.attrs.SequenceMaskAttrs") class SequenceMaskAttrs(Attrs): """Attributes used in sequence_mask operators""" From 5ef7f8de2946089eb83f292cb23fee9c778eb3ae Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 10:50:57 +0300 Subject: [PATCH 05/22] update onnx front-end: replace _op.scatter by _op.scatter_elements, add checks --- python/tvm/relay/frontend/onnx.py | 40 +++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2a1890627225..19af5850a547 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1967,7 +1967,7 @@ def _impl_v11(cls, inputs, attr, params): # Create a tensor of zeros then scatter our data through it. zeros_tensor = _op.zeros(total_output_shape, data_type) # We need to flatten all our tensors before scattering. - flat_tensor = _op.scatter( + flat_tensor = _op.scatter_elements( _op.reshape(zeros_tensor, [-1]), _op.reshape(indices, [-1]), _op.reshape(data, [-1]), @@ -2734,15 +2734,15 @@ def has_static_axes(): # Update the starts and ends according to axes if required. 
if axes is not None: data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype) - starts = _op.scatter( + starts = _op.scatter_elements( _op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype), axes, starts, axis=0, ) - ends = _op.scatter(data_shape, axes, ends, axis=0) + ends = _op.scatter_elements(data_shape, axes, ends, axis=0) if steps is not None: - steps = _op.scatter( + steps = _op.scatter_elements( _op.const([1] * data_rank, dtype=infer_type(steps).checked_type.dtype), axes, steps, @@ -2848,9 +2848,35 @@ class Scatter(OnnxOpConverter): """Operator converter for Scatter.""" @classmethod - def _impl_v9(cls, inputs, attr, params): + def _args_check(cls, inputs, attr): + assert ( + len(inputs) == 3 + ), "Scatter takes 3 inputs (data, indices, updates), {} given".format(len(inputs)) + assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"] + + data_rank = len(infer_shape(inputs[0])) + assert data_rank > 0, "Data rank higher than 0 is expected" + indices_shape = infer_shape(inputs[1]) + indices_rank = len(indices_shape) + assert indices_rank == data_rank, "Indices rank is not the same as data one" + updates_shape = infer_shape(inputs[2]) + updates_rank = len(updates_shape) + assert updates_rank == data_rank, "Updates rank is not the same as data one" + + for i in range(data_rank): + assert ( + indices_shape[i] == updates_shape[i] + ), "Indices dimension size should be the same as updates one" + axis = attr.get("axis", 0) - return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + assert -data_rank <= axis < data_rank, "Axis is out of bounds" + + return axis + + @classmethod + def _impl_v9(cls, inputs, attr, params): + axis = cls._args_check(inputs, attr) + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis) class ScatterElements(OnnxOpConverter): @@ -4991,7 +5017,7 @@ def _index_put(cls, inputs, attr, params): else: mode = "add" index_tensor = _op.stack(indices, axis=0) - return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode) + return _op.scatter_nd(in_tensor, index_tensor, values, mode) @classmethod def _reshape(cls, inputs, attr, params): From 87ec6e071c6a562ec64082f6dab64bd96b0d3e6c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 14:18:47 +0300 Subject: [PATCH 06/22] update oneflow front-end --- python/tvm/relay/frontend/oneflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index ff4b5a5bcc42..1aba9e64190f 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -1227,7 +1227,7 @@ class Scatter(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): axis = attrs.get("axis", 0) - return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis) class Unsqueeze(OneFlowOpConverter): From 82855ee71fb0a73824fd1a3236c220ae5a5bc31b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 14:21:21 +0300 Subject: [PATCH 07/22] update paddlepaddle front-end --- python/tvm/relay/frontend/paddlepaddle.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index e688369a072a..78895e4b49e0 100755 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -1741,10 +1741,10 @@ def convert_scatter(g, op, block): index = 
_op.transform.broadcast_to(index, shape) if overwrite: - out = _op.scatter(x, index, updates, axis=0) + out = _op.scatter_elements(x, index, updates, axis=0) else: out = _op.scatter_elements(_op.zeros_like(x), index, updates, axis=0, reduction="add") - out += _op.scatter(x, index, _op.zeros_like(updates), axis=0) + out += _op.scatter_elements(x, index, _op.zeros_like(updates), axis=0) g.add_node(op.output("Out")[0], out) @@ -1826,7 +1826,7 @@ def convert_slice(g, op, block): if len(axes) < dims: if isinstance(starts, _expr.Expr): - starts = _op.scatter( + starts = _op.scatter_elements( _op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype), indices, starts, @@ -1857,7 +1857,7 @@ def convert_slice(g, op, block): if len(axes) < dims: if isinstance(ends, _expr.Expr): - ends = _op.scatter( + ends = _op.scatter_elements( _expr.const( np.array([np.iinfo(np.int32).max] * dims), dtype=infer_type(ends).checked_type.dtype, @@ -1892,7 +1892,7 @@ def convert_slice(g, op, block): if len(axes) < dims: if isinstance(strides, _expr.Expr): - strides = _op.scatter( + strides = _op.scatter_elements( _expr.const( np.array([1] * dims), dtype=infer_type(strides).checked_type.dtype, From 0dcf131a007905341715e4f0b0ad2426063630c3 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 14:25:26 +0300 Subject: [PATCH 08/22] update pytorch utils --- python/tvm/relay/frontend/pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index da4c9e039e54..7de1248bda77 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -331,7 +331,7 @@ def do_where(levels, _): scatter_indices = is_op("repeat")(scatter_indices) scatter_indices = is_op("repeat")(scatter_indices) - scatter_res = is_op("scatter")(scatter_res, scatter_indices, roi_align_results[i]) + scatter_res = is_op("scatter_elements")(scatter_res, scatter_indices, roi_align_results[i]) return is_op("reshape")(scatter_res) From 026a231c43226550fac63a12791a23bb0feab9cc Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 14:27:34 +0300 Subject: [PATCH 09/22] remove front-end scatter definition --- python/tvm/relay/op/transform.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 833d14eb5897..2d39359bf855 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -353,31 +353,6 @@ def argwhere(condition): return _make.argwhere(condition) -def scatter(data, indices, updates, axis): - """Update data at positions defined by indices with values in updates. - - Parameters - ---------- - data : relay.Expr - The input data to the operator. - - indices : relay.Expr - The index locations to update. - - updates : relay.Expr - The values to update. - - axis : int - The axis to scatter on. - - Returns - ------- - ret : relay.Expr - The computed result. - """ - return _make.scatter(data, indices, updates, axis) - - def scatter_elements(data, indices, updates, axis=0, reduction="update"): """Scatter elements with updating data by reduction of values in updates at positions defined by indices. 
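
[Editor's note, not part of the patch series] Patches 01-09 route every former
`scatter` call site through `scatter_elements`. With `reduction="update"` the
new op reproduces the removed op exactly; `"add"` and `"mul"` fold the update
into the destination instead of overwriting it. A minimal NumPy sketch of the
semantics (`scatter_elements_ref` is an illustrative name, not a TVM function):

```python
import numpy as np

def scatter_elements_ref(data, indices, updates, axis=0, reduction="update"):
    """NumPy reference for scatter_elements (ONNX ScatterElements semantics)."""
    out = data.copy()
    # Every element of `indices` names a destination coordinate along `axis`;
    # all other coordinates are taken from the element's own position.
    for pos in np.ndindex(*indices.shape):
        target = list(pos)
        target[axis] = indices[pos]
        if reduction == "update":
            out[tuple(target)] = updates[pos]
        elif reduction == "add":
            out[tuple(target)] += updates[pos]
        elif reduction == "mul":
            out[tuple(target)] *= updates[pos]
        else:
            raise ValueError("reduction must be 'update', 'add' or 'mul'")
    return out
```
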
From 8bd9965a1d5824e0b686fb6baa7dd47bf81d23c7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 14:37:09 +0300 Subject: [PATCH 10/22] fix scatter strategy for rocm --- python/tvm/relay/op/strategy/rocm.py | 14 +++++++++----- python/tvm/topi/cuda/scatter.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 89cac0db4ab9..6e6053ee0603 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -105,23 +105,27 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target): return strategy -@scatter_strategy.register(["rocm"]) -def scatter_cuda(attrs, inputs, out_type, target): +@scatter_elements_strategy.register(["rocm"]) +def scatter_elements_cuda(attrs, inputs, out_type, target): """scatter rocm strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_scatter(topi.cuda.scatter), + wrap_compute_scatter_elements(topi.cuda.scatter_elements), wrap_topi_schedule(topi.cuda.schedule_scatter), name="scatter.rocm", plevel=10, ) rank = len(inputs[0].shape) + reduction = attrs.get("reduction", None) + if reduction is None: + reduction = b"update" + reduction = reduction.decode("utf-8") - with SpecializedCondition(rank == 1): + with SpecializedCondition(rank == 1 and reduction == "update"): if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"): strategy.add_implementation( - wrap_compute_scatter(topi.cuda.scatter_via_sort), + wrap_compute_scatter_elements(topi.cuda.scatter_via_sort), wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), name="scatter_via_sort.rocm", plevel=9, # use the sequential version by default diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index c88c3086f317..e4ad6eea9553 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -540,7 +540,7 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out): @autotvm.register_topi_compute("scatter_via_sort.cuda") -def scatter_via_sort(cfg, data, indices, updates, axis=0): +def scatter_via_sort(cfg, data, indices, updates, axis=0, _): """Update data at positions defined by indices with values in updates Parameters From 0e1e3a7dd1b0c087ddd5c55600f9ad348d8815c6 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 14:40:21 +0300 Subject: [PATCH 11/22] small update --- python/tvm/relay/transform/mixed_precision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py index 5018ba9ba9a7..f6bb8b815085 100644 --- a/python/tvm/relay/transform/mixed_precision.py +++ b/python/tvm/relay/transform/mixed_precision.py @@ -63,6 +63,8 @@ "tile", "dyn.tile", "scatter", + "scatter_elements", + "scatter_nd", "full", "dyn.full", "nn.depth_to_space", From 15e4cd642c94d2be9d1b9b812eefe4211c3534fd Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 14:53:25 +0300 Subject: [PATCH 12/22] remove scatter definition in back-end --- src/relay/op/tensor/transform.cc | 48 -------------------------------- 1 file changed, 48 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 907141c9cb6a..6ecabd4b2aba 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1095,54 +1095,6 @@ non-zero)doc" TVM_ADD_FILELINE) .set_attr("TOpPattern", kOpaque) .set_support_level(10); -// Scatter 
-TVM_REGISTER_NODE_TYPE(ScatterAttrs);
-
-// Scatter
-bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                const TypeReporter& reporter) {
-  ICHECK_EQ(num_inputs, 3);
-  ICHECK_EQ(types.size(), 4);
-  auto data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) {
-    return false;
-  }
-  auto indices = types[1].as<TensorTypeNode>();
-  if (indices == nullptr) {
-    return false;
-  }
-  auto updates = types[2].as<TensorTypeNode>();
-  if (updates == nullptr) {
-    return false;
-  }
-  ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
-      << "indices of scatter must be tensor of integer";
-  const auto param = attrs.as<ScatterAttrs>();
-  ICHECK(param != nullptr);
-  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
-  return true;
-}
-
-TVM_REGISTER_GLOBAL("relay.op._make.scatter")
-    .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
-      auto attrs = make_object<ScatterAttrs>();
-      attrs->axis = std::move(axis);
-      static const Op& op = Op::Get("scatter");
-      return Call(op, {data, indices, updates}, Attrs(attrs), {});
-    });
-
-RELAY_REGISTER_OP("scatter")
-    .describe(
-        R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE)
-    .set_num_inputs(3)
-    .add_argument("data", "Tensor", "The input data tensor.")
-    .add_argument("indices", "Tensor", "The indices location tensor.")
-    .add_argument("updates", "Tensor", "The values to update the input with.")
-    .add_type_rel("Scatter", ScatterRel)
-    .set_attr<TOpIsStateful>("TOpIsStateful", false)
-    .set_attr<TOpPattern>("TOpPattern", kOpaque)
-    .set_support_level(10);
-
 // scatter_elements operator
 TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);

From a6894f8a56fcaa812e9a01fe752c9f0180e3055c Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Thu, 16 Feb 2023 15:09:09 +0300
Subject: [PATCH 13/22] remove scatter strategy for cuda, gpu.
transfer special case to scatter_elements --- python/tvm/relay/op/strategy/cuda.py | 34 ++++++++++------------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index e0229a615d50..3c822f70573a 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1062,23 +1062,27 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target): return strategy -@scatter_strategy.register(["cuda", "gpu"]) -def scatter_cuda(attrs, inputs, out_type, target): - """scatter cuda strategy""" +@scatter_elements_strategy.register(["cuda", "gpu"]) +def scatter_elements_cuda(attrs, inputs, out_type, target): + """scatter elements cuda strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_scatter(topi.cuda.scatter), - wrap_topi_schedule(topi.cuda.schedule_scatter), - name="scatter.cuda", + wrap_compute_scatter_elements(topi.cuda.scatter_elements), + wrap_topi_schedule(topi.cuda.schedule_extern), + name="scatter_elements.cuda", plevel=10, ) rank = len(inputs[0].shape) + reduction = attrs.get("reduction", None) + if reduction is None: + reduction = b"update" + reduction = reduction.decode("utf-8") - with SpecializedCondition(rank == 1): + with SpecializedCondition(rank == 1 and reduction == "update"): if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"): strategy.add_implementation( - wrap_compute_scatter(topi.cuda.scatter_via_sort), + wrap_compute_scatter_elements(topi.cuda.scatter_via_sort), wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort), name="scatter_via_sort.cuda", plevel=9, # use the sequential version by default @@ -1086,20 +1090,6 @@ def scatter_cuda(attrs, inputs, out_type, target): return strategy -@scatter_elements_strategy.register(["cuda", "gpu"]) -def scatter_elements_cuda(attrs, inputs, out_type, target): - """scatter elements cuda strategy""" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_scatter_elements(topi.cuda.scatter_elements), - wrap_topi_schedule(topi.cuda.schedule_extern), - name="scatter_elements.cuda", - plevel=10, - ) - # TODO(vvchernov): There is possible specification for rank=1 as for scatter - return strategy - - @scatter_nd_strategy.register(["cuda", "gpu"]) def scatter_nd_cuda(attrs, inputs, out_type, target): """scatter_nd cuda strategy""" From c708acf7e4699a1ca5784941eecaf42d4cd9d7bc Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 15:45:25 +0300 Subject: [PATCH 14/22] fix test --- tests/python/relay/test_op_level3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index f18e935b57c8..493bf00fc6ad 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1002,7 +1002,7 @@ def verify_scatter(dshape, ishape, axis=0, indices_dtype="int64"): d = relay.var("d", relay.TensorType(dshape, "float32")) i = relay.var("i", relay.TensorType(ishape, indices_dtype)) u = relay.var("u", relay.TensorType(ishape, "float32")) - z = relay.op.scatter(d, i, u, axis) + z = relay.op.scatter_elements(d, i, u, axis) func = relay.Function([d, i, u], z) @@ -1055,7 +1055,7 @@ def test_dynamic_scatter(self, target, dev, executor_kind, dshape, ishape, axis) d = relay.var("d", relay.TensorType([relay.Any() for i in range(len(dshape))], "float32")) i = relay.var("i", relay.TensorType([relay.Any() for i in range(len(ishape))], "int64")) u 
= relay.var("u", relay.TensorType([relay.Any() for i in range(len(ishape))], "float32")) - z = relay.op.scatter(d, i, u, axis) + z = relay.op.scatter_elements(d, i, u, axis) func = relay.Function([d, i, u], z) From bd111dd523786b1f78100ffcb4d61ae568582329 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 17 Feb 2023 10:11:14 +0300 Subject: [PATCH 15/22] small fix --- python/tvm/relay/frontend/onnx.py | 6 +++--- python/tvm/topi/cuda/scatter.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 19af5850a547..7c55bfefb71d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2849,9 +2849,9 @@ class Scatter(OnnxOpConverter): @classmethod def _args_check(cls, inputs, attr): - assert ( - len(inputs) == 3 - ), "Scatter takes 3 inputs (data, indices, updates), {} given".format(len(inputs)) + assert len(inputs) == 3, "Scatter takes 3 inputs (data, indices, updates), {} given".format( + len(inputs) + ) assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"] data_rank = len(infer_shape(inputs[0])) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index e4ad6eea9553..496775b206fa 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -540,7 +540,7 @@ def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out): @autotvm.register_topi_compute("scatter_via_sort.cuda") -def scatter_via_sort(cfg, data, indices, updates, axis=0, _): +def scatter_via_sort(cfg, data, indices, updates, axis=0, reduction="add"): """Update data at positions defined by indices with values in updates Parameters @@ -562,6 +562,7 @@ def scatter_via_sort(cfg, data, indices, updates, axis=0, _): ret : relay.Expr The computed result. 
""" + assert reduction == "add" if axis < 0: axis += len(data.shape) assert axis == 0 and len(data.shape) == 1, "sorting based scatter only supported for 1d input" From c6fbd9aa02e5e71398d3b94b07491f1af334a752 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 17 Feb 2023 12:27:28 +0300 Subject: [PATCH 16/22] upstream scatter with torch description --- python/tvm/relay/frontend/pytorch.py | 45 +++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 62851506fc89..dde3d860136e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2552,11 +2552,54 @@ def nonzero_numpy(self, inputs, input_types): return self.nonzero(inputs, input_types, is_numpy_style=False) def scatter(self, inputs, input_types): + assert len(inputs) == 5, ( + "scatter takes 5 inputs (data, dim, index, src, reduce), " + + "but {} given".format(len(inputs)) + ) data = inputs[0] axis = int(inputs[1]) index = inputs[2] src = inputs[3] - return _op.scatter_elements(data, index, src, axis) + reduce = inputs[4] + + data_shape = self.infer_shape(data) + data_rank = len(data_shape) + index_shape = self.infer_shape(index) + index_rank = len(index_shape) + # When index is empty, the operation returns data unchanged + if index_rank == 0: + return data + assert self.infer_type(src).dtype == self.infer_type(data).dtype, ( + "The same data types for data and src are expected" + ) + if np.isscalar(src): + assert self.infer_type(src).dtype == "float", "Scalar source can be float only" + src = _op.broadcast_to_like(src, data_shape) + src_shape = data_shape + else: + src_shape = self.infer_shape(inputs[3]) + src_rank = len(src_shape) + assert data_rank == index_rank, "Index rank is not the same as data rank" + assert data_rank == src_rank, "Src rank is not the same as data rank" + + assert 0 <= axis < data_rank, "Dim is out of bounds" + + for i in range(data_rank): + assert index_shape[i] <= src_shape[i], "Index dim size should be less than src one" + if i != axis: + assert ( + index_shape[i] <= data_shape[i] + ), "Index dim size should be less than data one" + + if reduce is None: + reduce = "update" + elif reduce == "multiply": + reduce = "mul" + assert reduce in ["update", "add", "mul"], ( + "reduce arg is expected from \"add\", \"multiply\" or None" + ) + + return _op.scatter_elements(data, index, src, axis, reduce) def index_put(self, inputs, input_types): in_tensor = inputs[0] From 5aa9e7b5c7671a4114315834bf2ec3893cc05e2f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 19 Feb 2023 18:15:25 +0300 Subject: [PATCH 17/22] last upstream of scatter in pytorch front-end --- python/tvm/relay/frontend/pytorch.py | 25 +++++++++++-------- tests/python/frontend/pytorch/test_forward.py | 9 ++++++- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index dde3d860136e..66b4549b4ee7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2552,32 +2552,33 @@ def nonzero_numpy(self, inputs, input_types): return self.nonzero(inputs, input_types, is_numpy_style=False) def scatter(self, inputs, input_types): - assert len(inputs) == 5, ( - "scatter takes 5 inputs (data, dim, index, src, reduce), " + assert len(inputs) == 4 or len(inputs) == 5, ( + "scatter takes 4 or 5 inputs: data, dim, index, src, reduce (optional), " + "but {} given".format(len(inputs)) ) 
data = inputs[0] axis = int(inputs[1]) index = inputs[2] src = inputs[3] - reduce = inputs[4] + if len(inputs) == 5: + reduce = inputs[4] + else: + reduce = "update" data_shape = self.infer_shape(data) data_rank = len(data_shape) index_shape = self.infer_shape(index) index_rank = len(index_shape) # When index is empty, the operation returns data unchanged - if index_rank == 0: + if self.is_empty_shape(index_shape): return data - assert self.infer_type(src).dtype == self.infer_type(data).dtype, ( - "The same data types for data and src are expected" - ) + if np.isscalar(src): assert self.infer_type(src).dtype == "float", "Scalar source can be float only" src = _op.broadcast_to_like(src, data_shape) src_shape = data_shape else: - src_shape = self.infer_shape(inputs[3]) + src_shape = self.infer_shape(src) src_rank = len(src_shape) assert data_rank == index_rank, "Index rank is not the same as data rank" assert data_rank == src_rank, "Src rank is not the same as data rank" @@ -2595,9 +2596,11 @@ def scatter(self, inputs, input_types): reduce = "update" elif reduce == "multiply": reduce = "mul" - assert reduce in ["update", "add", "mul"], ( - "reduce arg is expected from \"add\", \"multiply\" or None" - ) + assert reduce in [ + "update", + "add", + "mul", + ], 'reduce arg is expected from "add", "multiply" or None' return _op.scatter_elements(data, index, src, axis, reduce) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 21defa6a59b2..f7acb1b62150 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4232,12 +4232,19 @@ def test_fn_scatter_add(dim): verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], targets) verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets) - # Check empty indices for scatter_add + # Check empty indices in_data = torch.zeros(2, 4) in_index = torch.empty((0,)) in_src = torch.rand(2, 1) + verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets) verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets) + # Check scalar source + in_data = torch.zeros(2, 4) + in_index = torch.tensor([[2], [3]]) + in_src = torch.rand(size=[]) + verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets) + def test_forward_scatter_reduce(): """test_forward_scatter_reduce""" From 1a0ce4f1e0d2ef4b98c4e14b859a65d42f73c933 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 19 Feb 2023 20:38:15 +0300 Subject: [PATCH 18/22] fix reduction attribute in cuda strategy --- python/tvm/relay/op/strategy/cuda.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 3c822f70573a..c6ea692a8d0e 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1074,12 +1074,8 @@ def scatter_elements_cuda(attrs, inputs, out_type, target): ) rank = len(inputs[0].shape) - reduction = attrs.get("reduction", None) - if reduction is None: - reduction = b"update" - reduction = reduction.decode("utf-8") - with SpecializedCondition(rank == 1 and reduction == "update"): + with SpecializedCondition(rank == 1 and attrs.reduction == "update"): if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"): strategy.add_implementation( wrap_compute_scatter_elements(topi.cuda.scatter_via_sort), From 736564e046f5406ab8a5bf6f52aa6d47f3a4ee76 Mon Sep 17 00:00:00 2001 
From: Valery Chernov
Date: Mon, 20 Feb 2023 08:33:58 +0300
Subject: [PATCH 19/22] use scalar instead of tensor in test; update check for
 dynamic dims

---
 python/tvm/relay/frontend/pytorch.py          | 15 ++++++++++-----
 src/relay/op/tensor/transform.cc              |  1 -
 tests/python/frontend/pytorch/test_forward.py |  2 +-
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 66b4549b4ee7..a7da26fb5e75 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2586,11 +2586,16 @@ def scatter(self, inputs, input_types):
         assert 0 <= axis < data_rank, "Dim is out of bounds"
 
         for i in range(data_rank):
-            assert index_shape[i] <= src_shape[i], "Index dim size should be less than src one"
-            if i != axis:
-                assert (
-                    index_shape[i] <= data_shape[i]
-                ), "Index dim size should be less than data one"
+            index_dim = index_shape[i]
+            src_dim = src_shape[i]
+            data_dim = data_shape[i]
+            # Skip check for dynamic dimensions
+            if not any([isinstance(index_dim, tvm.tir.Any), isinstance(src_dim, tvm.tir.Any)]):
+                assert index_dim <= src_dim, "Index dim size should be less than src one"
+            if i != axis and not any(
+                [isinstance(index_dim, tvm.tir.Any), isinstance(data_dim, tvm.tir.Any)]
+            ):
+                assert index_dim <= data_dim, "Index dim size should be less than data one"
 
         if reduce is None:
             reduce = "update"
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 6ecabd4b2aba..1bae1a4d9aba 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1120,7 +1120,6 @@ bool ScatterElementsRel(const Array<Type>& types, int num_inputs, const Attrs& a
         << "ScatterElements: expect updates type to be TensorType but got " << types[2];
     return false;
   }
-  // TODO(vvchernov): ONNX requires int32 and int64
   ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
       << "ScatterElements: indices must be a tensor of integers.";
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f7acb1b62150..1e8132cb5933 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4242,7 +4242,7 @@ def test_fn_scatter_add(dim):
     # Check scalar source
     in_data = torch.zeros(2, 4)
     in_index = torch.tensor([[2], [3]])
-    in_src = torch.rand(size=[])
+    in_src = np.random.random()  # scalar
     verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets)

From 972677afb34a66f97430f5a898d480d0ca6e1ca0 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 20 Feb 2023 10:39:15 +0300
Subject: [PATCH 20/22] skip scalar source check in tests for scatter due to
 issue on torch side

---
 tests/python/frontend/pytorch/test_forward.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1e8132cb5933..b78b64421650 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4240,10 +4240,8 @@ def test_fn_scatter_add(dim):
     verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets)
 
     # Check scalar source
-    in_data = torch.zeros(2, 4)
-    in_index = torch.tensor([[2], [3]])
-    in_src = np.random.random()  # scalar
-    verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets)
+    # TODO(vvchernov): Scalar source is supported on TVM side, but torch fails with
+    # input Tuple(Tensor,
Tensor, float). What does scalar mean for torch in this case? def test_forward_scatter_reduce(): From 441e00103fada9e6e0a93ebb91e4ca77dee1f7a2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 22 Feb 2023 08:47:36 +0300 Subject: [PATCH 21/22] remove scatter op implementation from topi/cuda --- python/tvm/relay/op/strategy/rocm.py | 4 +- python/tvm/topi/cuda/scatter.py | 437 +-------------------------- 2 files changed, 5 insertions(+), 436 deletions(-) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 6e6053ee0603..71d6b4524c95 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -111,8 +111,8 @@ def scatter_elements_cuda(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_scatter_elements(topi.cuda.scatter_elements), - wrap_topi_schedule(topi.cuda.schedule_scatter), - name="scatter.rocm", + wrap_topi_schedule(topi.cuda.schedule_extern), + name="scatter_elements.rocm", plevel=10, ) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 496775b206fa..39ef5a5a42ca 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -14,446 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument -"""Scatter operator """ +# pylint: disable=invalid-name +"""Scatter operators""" import tvm from tvm import te, tir, autotvm from ..scatter import _verify_scatter_nd_inputs from ..generic import schedule_extern from .nms import atomic_add from .sort import stable_sort_by_key_thrust -from ..utils import prod, ceil_div - - -def _memcpy_ir(ib, out_ptr, data_ptr, shape): - fused = prod(shape) - with ib.new_scope(): - num_thread = int(tvm.target.Target.current(allow_none=False).max_num_threads) - num_blocks = ceil_div(fused, num_thread) - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", num_blocks) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", num_thread) - tid = bx * num_thread + tx - - with ib.if_scope(tid < fused): - out_ptr[tid] = data_ptr[tid] - - -def gen_ir_1d(data, indices, updates, axis, out, update_func): - """Generate scatter ir for 1d inputs - - Parameters - ---------- - data : tir.Tensor - The input data to the operator. - - indices : tir.Tensor - The index locations to update. - - updates : tir.Tensor - The values to update. - - axis : int - The axis to scatter on - - out : tir.Tensor - The output tensor. - - update_func: function - The function to be applied to a destination and the corresponding update. - - Returns - ------- - ret : tir - The computational ir. 
- """ - assert axis == 0 - n = data.shape[0] - - ib = tvm.tir.ir_builder.create() - - out_ptr = ib.buffer_ptr(out) - data_ptr = ib.buffer_ptr(data) - - _memcpy_ir(ib, out_ptr, data_ptr, data.shape) - - indices_ptr = ib.buffer_ptr(indices) - updates_ptr = ib.buffer_ptr(updates) - - ni = indices.shape[0] - - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", 1) - with ib.for_range(0, ni, name="i") as i: - index = indices_ptr[i] - with ib.if_scope(index < 0): - update_func(out_ptr, index + n, updates_ptr[i]) - with ib.else_scope(): - update_func(out_ptr, index, updates_ptr[i]) - - return ib.get() - - -def gen_ir_2d(data, indices, updates, axis, out, update_func): - """Generate scatter ir for 2d inputs - - Parameters - ---------- - data : tir.Tensor - The input data to the operator. - - indices : tir.Tensor - The index locations to update. - - updates : tir.Tensor - The values to update. - - axis : int - The axis to scatter on - - out : tir.Tensor - The output tensor. - - update_func: function - The function to be applied to a destination and the corresponding update - - Returns - ------- - ret : tir - The computational ir. - """ - n = data.shape[0] - c = data.shape[1] - - ib = tvm.tir.ir_builder.create() - - out_ptr = ib.buffer_ptr(out) - data_ptr = ib.buffer_ptr(data) - - _memcpy_ir(ib, out_ptr, data_ptr, data.shape) - - indices_ptr = ib.buffer_ptr(indices) - updates_ptr = ib.buffer_ptr(updates) - - ni = indices.shape[0] - ci = indices.shape[1] - - if axis == 0: - with ib.new_scope(): - j = te.thread_axis("blockIdx.x") - ib.scope_attr(j, "thread_extent", ci) - with ib.for_range(0, ni, name="i") as i: - idx = i * ci + j - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func(out_ptr, (index + n) * c + j, updates_ptr[idx]) - with ib.else_scope(): - update_func(out_ptr, index * c + j, updates_ptr[idx]) - else: - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", ni) - with ib.for_range(0, ci, name="j") as j: - idx = i * ci + j - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func(out_ptr, i * c + (index + c), updates_ptr[idx]) - with ib.else_scope(): - update_func(out_ptr, i * c + index, updates_ptr[idx]) - return ib.get() - - -def gen_ir_3d(data, indices, updates, axis, out, update_func): - """Generate scatter ir for 3d inputs - - Parameters - ---------- - data : tir.Tensor - The input data to the operator. - - indices : tir.Tensor - The index locations to update. - - updates : tir.Tensor - The values to update. - - axis : int - The axis to scatter on - - out : tir.Tensor - The output tensor. - - update_func: function - The function to be applied to a destination and the corresponding update - - Returns - ------- - ret : tir - The computational ir. 
- """ - warp_size = tvm.target.Target.current(False).thread_warp_size - - n = data.shape[0] - c = data.shape[1] - h = data.shape[2] - - ib = tvm.tir.ir_builder.create() - - out_ptr = ib.buffer_ptr(out) - data_ptr = ib.buffer_ptr(data) - - _memcpy_ir(ib, out_ptr, data_ptr, data.shape) - - indices_ptr = ib.buffer_ptr(indices) - updates_ptr = ib.buffer_ptr(updates) - ni = indices.shape[0] - ci = indices.shape[1] - hi = indices.shape[2] - - if axis == 0: - with ib.new_scope(): - j = te.thread_axis("blockIdx.x") - ib.scope_attr(j, "thread_extent", ci) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_: - k = k_ * warp_size + tx - with ib.if_scope(k < hi): - idx = (i * ci + j) * hi + k - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func(out_ptr, ((index + n) * c + j) * h + k, updates_ptr[idx]) - with ib.else_scope(): - update_func(out_ptr, (index * c + j) * h + k, updates_ptr[idx]) - elif axis == 1: - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", ni) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_: - k = k_ * warp_size + tx - with ib.if_scope(k < hi): - idx = (i * ci + j) * hi + k - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func(out_ptr, (i * c + (index + c)) * h + k, updates_ptr[idx]) - with ib.else_scope(): - update_func(out_ptr, (i * c + index) * h + k, updates_ptr[idx]) - else: - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", ni) - j = te.thread_axis("blockIdx.y") - ib.scope_attr(j, "thread_extent", ci) - with ib.for_range(0, hi, name="k") as k: - idx = (i * ci + j) * hi + k - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func(out_ptr, (i * c + j) * h + (index + h), updates_ptr[idx]) - with ib.else_scope(): - update_func(out_ptr, (i * c + j) * h + index, updates_ptr[idx]) - return ib.get() - - -def gen_ir_4d(data, indices, updates, axis, out, update_func): - """Generate scatter ir for 4d inputs - - Parameters - ---------- - data : tir.Tensor - The input data to the operator. - - indices : tir.Tensor - The index locations to update. - - updates : tir.Tensor - The values to update. - - axis : int - The axis to scatter on - - out : tir.Tensor - The output tensor. - - update_func: function - The function to be applied to a destination and the corresponding update - - Returns - ------- - ret : tir - The computational ir. 
- """ - warp_size = tvm.target.Target.current(False).thread_warp_size - - n = data.shape[0] - c = data.shape[1] - h = data.shape[2] - w = data.shape[3] - - ib = tvm.tir.ir_builder.create() - - out_ptr = ib.buffer_ptr(out) - data_ptr = ib.buffer_ptr(data) - _memcpy_ir(ib, out_ptr, data_ptr, data.shape) - - indices_ptr = ib.buffer_ptr(indices) - updates_ptr = ib.buffer_ptr(updates) - ni = indices.shape[0] - ci = indices.shape[1] - hi = indices.shape[2] - wi = indices.shape[3] - - if axis == 0: - with ib.new_scope(): - j = te.thread_axis("blockIdx.y") - ib.scope_attr(j, "thread_extent", ci) - k = te.thread_axis("blockIdx.z") - ib.scope_attr(k, "thread_extent", hi) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, ni, name="i") as i: - with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: - l = l_ * warp_size + tx - with ib.if_scope(l < wi): - idx = ((i * ci + j) * hi + k) * wi + l - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func( - out_ptr, (((index + n) * c + j) * h + k) * w + l, updates_ptr[idx] - ) - with ib.else_scope(): - update_func( - out_ptr, ((index * c + j) * h + k) * w + l, updates_ptr[idx] - ) - elif axis == 1: - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", ni) - k = te.thread_axis("blockIdx.z") - ib.scope_attr(k, "thread_extent", hi) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, ci, name="j") as j: - with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: - l = l_ * warp_size + tx - with ib.if_scope(l < wi): - idx = ((i * ci + j) * hi + k) * wi + l - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func( - out_ptr, ((i * c + (index + c)) * h + k) * w + l, updates_ptr[idx] - ) - with ib.else_scope(): - update_func( - out_ptr, ((i * c + index) * h + k) * w + l, updates_ptr[idx] - ) - elif axis == 2: - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", ni) - j = te.thread_axis("blockIdx.y") - ib.scope_attr(j, "thread_extent", ci) - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", warp_size) - with ib.for_range(0, hi, name="k") as k: - with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: - l = l_ * warp_size + tx - with ib.if_scope(l < wi): - idx = ((i * ci + j) * hi + k) * wi + l - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func( - out_ptr, ((i * c + j) * h + (index + h)) * w + l, updates_ptr[idx] - ) - with ib.else_scope(): - update_func( - out_ptr, ((i * c + j) * h + index) * w + l, updates_ptr[idx] - ) - else: - with ib.new_scope(): - i = te.thread_axis("blockIdx.x") - ib.scope_attr(i, "thread_extent", ni) - j = te.thread_axis("blockIdx.y") - ib.scope_attr(j, "thread_extent", ci) - k = te.thread_axis("blockIdx.z") - ib.scope_attr(k, "thread_extent", hi) - with ib.for_range(0, wi, name="l") as l: - idx = ((i * ci + j) * hi + k) * wi + l - index = indices_ptr[idx] - with ib.if_scope(index < 0): - update_func(out_ptr, ((i * c + j) * h + k) * w + (index + w), updates_ptr[idx]) - with ib.else_scope(): - update_func(out_ptr, ((i * c + j) * h + k) * w + index, updates_ptr[idx]) - return ib.get() - - -@autotvm.register_topi_compute("scatter.cuda") -def scatter(cfg, data, indices, updates, axis=0): - """Update data at positions defined by indices with values in updates - - Parameters - ---------- - data : relay.Expr - The input data to the operator. 
- - indices : relay.Expr - The index locations to update. - - updates : relay.Expr - The values to update. - - axis : int - The axis to scatter on - - Returns - ------- - ret : relay.Expr - The computed result. - """ - if axis < 0: - axis += len(data.shape) - assert axis >= 0 - assert axis < len(data.shape) - - rank = len(data.shape) - assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions" - - ir_funcs = { - 1: gen_ir_1d, - 2: gen_ir_2d, - 3: gen_ir_3d, - 4: gen_ir_4d, - } - - def update_func(dst_ptr, dst_index, update): - dst_ptr[dst_index] = update - - out_shape = data.shape - out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") - - cfg.add_flop(1) # A dummy value to satisfy AutoTVM - - out = te.extern( - [out_shape], - [data, indices, updates], - lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), - dtype=data.dtype, - out_buffers=[out_buf], - name="scatter_gpu", - tag="scatter_gpu", - ) - - return out - - -@autotvm.register_topi_schedule("scatter.cuda") -def schedule_scatter(_, outs): - return schedule_extern(outs) +from ..utils import ceil_div def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, out): From dd4cacc7b1cc0953e186e1e7c7233ad46b393cb7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 22 Feb 2023 08:57:22 +0300 Subject: [PATCH 22/22] remove scatter op implementation from topi. small clean code --- python/tvm/relay/op/strategy/generic.py | 1 + python/tvm/relay/op/strategy/rocm.py | 6 +- python/tvm/topi/scatter.py | 183 +----------------------- 3 files changed, 4 insertions(+), 186 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 2b12861e040e..b08d92a3cc93 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1558,6 +1558,7 @@ def scatter_elements_strategy(attrs, inputs, out_type, target): wrap_topi_schedule(topi.generic.schedule_extern), name="scatter_elements.generic", ) + # TODO(vvchernov): implement specialized case (rank=1, reduction="update"), see cuda strategy return strategy diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 71d6b4524c95..d80f3479754b 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -117,12 +117,8 @@ def scatter_elements_cuda(attrs, inputs, out_type, target): ) rank = len(inputs[0].shape) - reduction = attrs.get("reduction", None) - if reduction is None: - reduction = b"update" - reduction = reduction.decode("utf-8") - with SpecializedCondition(rank == 1 and reduction == "update"): + with SpecializedCondition(rank == 1 and attrs.reduction == "update"): if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"): strategy.add_implementation( wrap_compute_scatter_elements(topi.cuda.scatter_via_sort), diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index 45629c005f79..799b3d16733f 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -14,191 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks -"""Scatter operator""" +# pylint: disable=invalid-name +"""ScatterND operator""" from tvm import te, tir # hide redefinition of min and max from tvm.tir import expr -@te.hybrid.script -def _scatter_1d(data, indices, updates): - out = output_tensor(data.shape, data.dtype) - for i in range(data.shape[0]): - out[i] = data[i] - for i in range(indices.shape[0]): - out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] = updates[i] - return out - - -@te.hybrid.script -def _scatter_2d(data, indices, updates, axis): - out = output_tensor(data.shape, data.dtype) - for i in range(data.shape[0]): - for j in range(data.shape[1]): - out[i, j] = data[i, j] - if axis == 0: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - out[ - indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j - ] = updates[i, j] - else: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - out[ - i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis] - ] = updates[i, j] - - return out - - -@te.hybrid.script -def _scatter_3d(data, indices, updates, axis): - out = output_tensor(data.shape, data.dtype) - for i in range(data.shape[0]): - for j in range(data.shape[1]): - for k in range(data.shape[2]): - out[i, j, k] = data[i, j, k] - if axis == 0: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - for k in range(indices.shape[2]): - out[ - indices[i, j, k] - if indices[i, j, k] >= 0 - else indices[i, j, k] + data.shape[axis], - j, - k, - ] = updates[i, j, k] - elif axis == 1: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - for k in range(indices.shape[2]): - out[ - i, - indices[i, j, k] - if indices[i, j, k] >= 0 - else indices[i, j, k] + data.shape[axis], - k, - ] = updates[i, j, k] - else: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - for k in range(indices.shape[2]): - out[ - i, - j, - indices[i, j, k] - if indices[i, j, k] >= 0 - else indices[i, j, k] + data.shape[axis], - ] = updates[i, j, k] - - return out - - -@te.hybrid.script -def _scatter_4d(data, indices, updates, axis): - out = output_tensor(data.shape, data.dtype) - for i in range(data.shape[0]): - for j in range(data.shape[1]): - for k in range(data.shape[2]): - for l in range(data.shape[3]): - out[i, j, k, l] = data[i, j, k, l] - - if axis == 0: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - for k in range(indices.shape[2]): - for l in range(indices.shape[3]): - out[ - indices[i, j, k, l] - if indices[i, j, k, l] >= 0 - else indices[i, j, k, l] + data.shape[axis], - j, - k, - l, - ] = updates[i, j, k, l] - elif axis == 1: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - for k in range(indices.shape[2]): - for l in range(indices.shape[3]): - out[ - i, - indices[i, j, k, l] - if indices[i, j, k, l] >= 0 - else indices[i, j, k, l] + data.shape[axis], - k, - l, - ] = updates[i, j, k, l] - elif axis == 2: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - for k in range(indices.shape[2]): - for l in range(indices.shape[3]): - out[ - i, - j, - indices[i, j, k, l] - if indices[i, j, k, l] >= 0 - else indices[i, j, k, l] + data.shape[axis], - l, - ] = updates[i, j, k, l] - else: - for i in range(indices.shape[0]): - for j in range(indices.shape[1]): - for k in range(indices.shape[2]): - for l in range(indices.shape[3]): - out[ - i, - j, - k, - indices[i, 
j, k, l] - if indices[i, j, k, l] >= 0 - else indices[i, j, k, l] + data.shape[axis], - ] = updates[i, j, k, l] - - return out - - -def scatter(data, indices, updates, axis=0): - """Update data at positions defined by indices with values in updates - - Parameters - ---------- - data : relay.Expr - The input data to the operator. - - indices : relay.Expr - The index locations to update. - - updates : relay.Expr - The values to update. - - axis : int - The axis to scatter on - - Returns - ------- - ret : relay.Expr - The computed result. - """ - if axis < 0: - axis += len(data.shape) - assert axis >= 0 - assert axis < len(data.shape) - - if len(data.shape) == 1: - return _scatter_1d(data, indices, updates) - if len(data.shape) == 2: - return _scatter_2d(data, indices, updates, axis) - if len(data.shape) == 3: - return _scatter_3d(data, indices, updates, axis) - if len(data.shape) == 4: - return _scatter_4d(data, indices, updates, axis) - raise ValueError("scatter only support for 1-4 dimensions") - - def _verify_scatter_nd_inputs(data, indices, updates): mdim = int(indices.shape[0]) assert mdim <= len(data.shape), (
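
[Editor's note, not part of the patch series] For users, the migration is a
one-line change at each call site. A hedged end-to-end sketch (assumes a stock
TVM build with the LLVM backend; shapes and values here are arbitrary):

```python
import numpy as np
import tvm
from tvm import relay

d = relay.var("d", relay.TensorType((4,), "float32"))
i = relay.var("i", relay.TensorType((2,), "int64"))
u = relay.var("u", relay.TensorType((2,), "float32"))

# Before this series: z = relay.op.scatter(d, i, u, axis=0)
z = relay.op.scatter_elements(d, i, u, axis=0, reduction="update")
mod = tvm.IRModule.from_expr(relay.Function([d, i, u], z))

f = relay.create_executor("graph", mod=mod, device=tvm.cpu(), target="llvm").evaluate()
print(f(np.zeros(4, "float32"), np.array([1, 3], "int64"), np.array([7.0, 9.0], "float32")))
# -> [0. 7. 0. 9.]
```
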