8 changes: 0 additions & 8 deletions include/tvm/relay/attrs/transform.h
@@ -148,14 +148,6 @@ struct ReshapeLikeAttrs : public tvm::AttrsNode<ReshapeLikeAttrs> {
}
}; // struct ReshapeLikeAttrs

struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
}
};

struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
String reduction;
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/oneflow.py
@@ -1227,7 +1227,7 @@ class Scatter(OneFlowOpConverter):
@classmethod
def _impl_v1(cls, inputs, attrs, params):
axis = attrs.get("axis", 0)
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis)


class Unsqueeze(OneFlowOpConverter):
40 changes: 33 additions & 7 deletions python/tvm/relay/frontend/onnx.py
@@ -1967,7 +1967,7 @@ def _impl_v11(cls, inputs, attr, params):
# Create a tensor of zeros then scatter our data through it.
zeros_tensor = _op.zeros(total_output_shape, data_type)
# We need to flatten all our tensors before scattering.
flat_tensor = _op.scatter(
flat_tensor = _op.scatter_elements(
_op.reshape(zeros_tensor, [-1]),
_op.reshape(indices, [-1]),
_op.reshape(data, [-1]),
@@ -2734,15 +2734,15 @@ def has_static_axes():
# Update the starts and ends according to axes if required.
if axes is not None:
data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
starts = _op.scatter(
starts = _op.scatter_elements(
_op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype),
axes,
starts,
axis=0,
)
ends = _op.scatter(data_shape, axes, ends, axis=0)
ends = _op.scatter_elements(data_shape, axes, ends, axis=0)
if steps is not None:
steps = _op.scatter(
steps = _op.scatter_elements(
_op.const([1] * data_rank, dtype=infer_type(steps).checked_type.dtype),
axes,
steps,
@@ -2848,9 +2848,35 @@ class Scatter(OnnxOpConverter):
"""Operator converter for Scatter."""

@classmethod
def _impl_v9(cls, inputs, attr, params):
def _args_check(cls, inputs, attr):
assert len(inputs) == 3, "Scatter takes 3 inputs (data, indices, updates), {} given".format(
len(inputs)
)
assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"]

data_rank = len(infer_shape(inputs[0]))
assert data_rank > 0, "Data rank higher than 0 is expected"
indices_shape = infer_shape(inputs[1])
indices_rank = len(indices_shape)
assert indices_rank == data_rank, "Indices rank is not the same as data one"
updates_shape = infer_shape(inputs[2])
updates_rank = len(updates_shape)
assert updates_rank == data_rank, "Updates rank is not the same as data one"

for i in range(data_rank):
assert (
indices_shape[i] == updates_shape[i]
), "Indices dimension size should be the same as updates one"

axis = attr.get("axis", 0)
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
assert -data_rank <= axis < data_rank, "Axis is out of bounds"

return axis

@classmethod
def _impl_v9(cls, inputs, attr, params):
axis = cls._args_check(inputs, attr)
return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis)


class ScatterElements(OnnxOpConverter):
@@ -4991,7 +5017,7 @@ def _index_put(cls, inputs, attr, params):
else:
mode = "add"
index_tensor = _op.stack(indices, axis=0)
return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
return _op.scatter_nd(in_tensor, index_tensor, values, mode)

@classmethod
def _reshape(cls, inputs, attr, params):
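For reference, a minimal NumPy sketch (not part of the patch) of the semantics the converter now targets: `scatter_elements` with the default `reduction="update"` behaves like the removed `scatter` op, writing each update into the output at the position given by `indices` along `axis`. Shapes below are illustrative.

```python
import numpy as np

def scatter_elements_ref(data, indices, updates, axis=0):
    """NumPy reference for scatter_elements with reduction="update"."""
    out = data.copy()
    for idx in np.ndindex(indices.shape):
        pos = list(idx)
        pos[axis] = indices[idx]  # replace the coordinate along `axis` with the index value
        out[tuple(pos)] = updates[idx]
    return out

data = np.zeros((3, 3), dtype="float32")
indices = np.array([[1, 0, 2], [0, 2, 1]], dtype="int64")
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype="float32")
print(scatter_elements_ref(data, indices, updates, axis=0))
```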
10 changes: 5 additions & 5 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -1741,10 +1741,10 @@ def convert_scatter(g, op, block):
index = _op.transform.broadcast_to(index, shape)

if overwrite:
out = _op.scatter(x, index, updates, axis=0)
out = _op.scatter_elements(x, index, updates, axis=0)
else:
out = _op.scatter_elements(_op.zeros_like(x), index, updates, axis=0, reduction="add")
out += _op.scatter(x, index, _op.zeros_like(updates), axis=0)
out += _op.scatter_elements(x, index, _op.zeros_like(updates), axis=0)
g.add_node(op.output("Out")[0], out)


@@ -1826,7 +1826,7 @@ def convert_slice(g, op, block):

if len(axes) < dims:
if isinstance(starts, _expr.Expr):
starts = _op.scatter(
starts = _op.scatter_elements(
_op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype),
indices,
starts,
@@ -1857,7 +1857,7 @@ def convert_slice(g, op, block):

if len(axes) < dims:
if isinstance(ends, _expr.Expr):
ends = _op.scatter(
ends = _op.scatter_elements(
_expr.const(
np.array([np.iinfo(np.int32).max] * dims),
dtype=infer_type(ends).checked_type.dtype,
@@ -1892,7 +1892,7 @@ def convert_slice(g, op, block):

if len(axes) < dims:
if isinstance(strides, _expr.Expr):
strides = _op.scatter(
strides = _op.scatter_elements(
_expr.const(
np.array([1] * dims),
dtype=infer_type(strides).checked_type.dtype,
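A hedged NumPy sketch of the non-overwrite (scatter-add) decomposition used in `convert_scatter` above: accumulate the updates into a zero tensor with `reduction="add"`, then add back the original tensor with the indexed rows zeroed out (by scattering `zeros_like(updates)` into `x` with the default reduction). The index follows Paddle's `scatter`, i.e. a 1-D list of row indices; names here are illustrative.

```python
import numpy as np

def paddle_scatter_add_ref(x, index, updates):
    # scatter_elements(zeros_like(x), index, updates, axis=0, reduction="add")
    accum = np.zeros_like(x)
    np.add.at(accum, index, updates)
    # scatter_elements(x, index, zeros_like(updates), axis=0): zero out the indexed rows
    masked = x.copy()
    masked[index] = 0
    return accum + masked

x = np.ones((4, 2), dtype="float32")
index = np.array([0, 2, 0], dtype="int64")       # row 0 receives two updates
updates = np.full((3, 2), 0.5, dtype="float32")
print(paddle_scatter_add_ref(x, index, updates))
# row 0 becomes 0.5 + 0.5, row 2 becomes 0.5, rows 1 and 3 keep their original values
```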
59 changes: 55 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
@@ -496,7 +496,7 @@ def slice(self, inputs, input_types):
end[dim] = target_end
else:
target_end = _expr.const(target_end)
end = _op.scatter(
end = _op.scatter_elements(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
@@ -508,7 +508,7 @@
ttype = self.infer_type(target_end).dtype
if str(ttype) != axis_dtype:
target_end = _op.cast(target_end, axis_dtype)
end = _op.scatter(
end = _op.scatter_elements(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
@@ -2552,11 +2552,62 @@ def nonzero_numpy(self, inputs, input_types):
return self.nonzero(inputs, input_types, is_numpy_style=False)

def scatter(self, inputs, input_types):
assert len(inputs) == 4 or len(inputs) == 5, (
"scatter takes 4 or 5 inputs: data, dim, index, src, reduce (optional), "
+ "but {} given".format(len(inputs))
)
data = inputs[0]
axis = int(inputs[1])
index = inputs[2]
src = inputs[3]
return _op.transform.scatter(data, index, src, axis)
if len(inputs) == 5:
reduce = inputs[4]
else:
reduce = "update"

data_shape = self.infer_shape(data)
data_rank = len(data_shape)
index_shape = self.infer_shape(index)
index_rank = len(index_shape)
# When index is empty, the operation returns data unchanged
if self.is_empty_shape(index_shape):
return data

if np.isscalar(src):
assert self.infer_type(src).dtype == "float", "Scalar source can be float only"
src = _op.broadcast_to_like(src, data_shape)
src_shape = data_shape
else:
src_shape = self.infer_shape(src)
src_rank = len(src_shape)
assert data_rank == index_rank, "Index rank is not the same as data rank"
assert data_rank == src_rank, "Src rank is not the same as data rank"

assert 0 <= axis < data_rank, "Dim is out of bounds"

for i in range(data_rank):
index_dim = index_shape[i]
src_dim = src_shape[i]
data_dim = data_shape[i]
# Skip check for dynamic dimensions
if not any([isinstance(index_dim, tvm.tir.Any), isinstance(src_dim, tvm.tir.Any)]):
assert index_dim <= src_dim, "Index dim size should be less than src one"
if i != axis and not any(
[isinstance(index_dim, tvm.tir.Any), isinstance(data_dim, tvm.tir.Any)]
):
assert index_dim <= data_dim, "Index dim size should be less than data one"

if reduce is None:
reduce = "update"
elif reduce == "multiply":
reduce = "mul"
assert reduce in [
"update",
"add",
"mul",
], 'reduce arg is expected from "add", "multiply" or None'

return _op.scatter_elements(data, index, src, axis, reduce)

def index_put(self, inputs, input_types):
in_tensor = inputs[0]
@@ -2569,7 +2620,7 @@ def index_put(self, inputs, input_types):
mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)
return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)
return _op.scatter_nd(in_tensor, index_tensor, values, mode)

def scalar_tensor(self, inputs, input_types):
data = inputs[0]
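A minimal sketch (assuming a PyTorch build where `Tensor.scatter` accepts the `reduce` keyword) of a module that exercises the new 5-input path: the converter normalizes `reduce="multiply"` to `"mul"` (and `None` to `"update"`) before emitting `scatter_elements`. Names and shapes here are illustrative, not part of the patch.

```python
import torch
from tvm import relay

class ScatterMul(torch.nn.Module):
    def forward(self, data, index, src):
        # traced as aten::scatter with 5 inputs: data, dim, index, src, reduce
        return data.scatter(0, index, src, reduce="multiply")

data = torch.ones(3, 4)
index = torch.tensor([[0, 1, 2, 0], [2, 0, 1, 1]])
src = torch.full((2, 4), 2.0)
traced = torch.jit.trace(ScatterMul(), (data, index, src))

mod, params = relay.frontend.from_pytorch(
    traced,
    [("data", ((3, 4), "float32")), ("index", ((2, 4), "int64")), ("src", ((2, 4), "float32"))],
)
print(mod)  # expected to contain scatter_elements(..., axis=0, reduction="mul")
```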
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/pytorch_utils.py
@@ -331,7 +331,7 @@ def do_where(levels, _):
scatter_indices = is_op("repeat")(scatter_indices)
scatter_indices = is_op("repeat")(scatter_indices)

scatter_res = is_op("scatter")(scatter_res, scatter_indices, roi_align_results[i])
scatter_res = is_op("scatter_elements")(scatter_res, scatter_indices, roi_align_results[i])

return is_op("reshape")(scatter_res)

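For context, a short sketch of the dataflow-pattern idiom used in the rewrite above; after this change the ROI-align post-processing pattern matches calls to `scatter_elements` instead of the removed `scatter`.

```python
from tvm.relay.dataflow_pattern import is_op, wildcard

# matches scatter_elements(data, indices, updates) with any arguments
scatter_pat = is_op("scatter_elements")(wildcard(), wildcard(), wildcard())
```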
10 changes: 0 additions & 10 deletions python/tvm/relay/op/_transform.py
@@ -104,15 +104,6 @@ def compute_strided_set(attrs, inputs, output_type):
# argwhere
_reg.register_strategy("argwhere", strategy.argwhere_strategy)

# scatter
@_reg.register_compute("scatter")
def compute_scatter(attrs, inputs, output_type):
"""Compute definition of scatter"""
return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]


_reg.register_strategy("scatter", strategy.scatter_strategy)

# sparse_fill_empty_rows
@_reg.register_compute("sparse_fill_empty_rows")
def compute_sparse_fill_empty_rows(attrs, inputs, output_type):
@@ -677,7 +668,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return ValueError("Does not support rank higher than 5 in argwhere")


_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)

5 changes: 0 additions & 5 deletions python/tvm/relay/op/op_attrs.py
@@ -529,11 +529,6 @@ class RequantizeAttrs(Attrs):
"""Attributes used in requantize operators"""


@tvm._ffi.register_object("relay.attrs.ScatterAttrs")
class ScatterAttrs(Attrs):
"""Attributes used in scatter operators"""


@tvm._ffi.register_object("relay.attrs.SequenceMaskAttrs")
class SequenceMaskAttrs(Attrs):
"""Attributes used in sequence_mask operators"""
30 changes: 8 additions & 22 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1062,44 +1062,30 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_strategy.register(["cuda", "gpu"])
def scatter_cuda(attrs, inputs, out_type, target):
"""scatter cuda strategy"""
@scatter_elements_strategy.register(["cuda", "gpu"])
def scatter_elements_cuda(attrs, inputs, out_type, target):
"""scatter elements cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
wrap_topi_schedule(topi.cuda.schedule_scatter),
name="scatter.cuda",
wrap_compute_scatter_elements(topi.cuda.scatter_elements),
wrap_topi_schedule(topi.cuda.schedule_extern),
name="scatter_elements.cuda",
plevel=10,
)

rank = len(inputs[0].shape)

with SpecializedCondition(rank == 1):
with SpecializedCondition(rank == 1 and attrs.reduction == "update"):
if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_via_sort),
wrap_compute_scatter_elements(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
name="scatter_via_sort.cuda",
plevel=9, # use the sequential version by default
)
return strategy


@scatter_elements_strategy.register(["cuda", "gpu"])
def scatter_elements_cuda(attrs, inputs, out_type, target):
"""scatter elements cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_elements(topi.cuda.scatter_elements),
wrap_topi_schedule(topi.cuda.schedule_extern),
name="scatter_elements.cuda",
plevel=10,
)
# TODO(vvchernov): There is possible specification for rank=1 as for scatter
return strategy


@scatter_nd_strategy.register(["cuda", "gpu"])
def scatter_nd_cuda(attrs, inputs, out_type, target):
"""scatter_nd cuda strategy"""
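A hedged sketch of a workload that would hit the specialized path registered above: a rank-1 `scatter_elements` with the default `reduction="update"`, compiled for a CUDA target with Thrust available (assumes a TVM build with `USE_THRUST=ON`); in that case `scatter_via_sort.cuda` can be selected, otherwise the generic extern schedule is used.

```python
import tvm
from tvm import relay

data = relay.var("data", shape=(1024,), dtype="float32")
indices = relay.var("indices", shape=(128,), dtype="int64")
updates = relay.var("updates", shape=(128,), dtype="float32")
out = relay.op.scatter_elements(data, indices, updates, axis=0)  # reduction defaults to "update"
mod = tvm.IRModule.from_expr(relay.Function([data, indices, updates], out))

target = tvm.target.Target("cuda -libs=thrust")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
```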
22 changes: 1 addition & 21 deletions python/tvm/relay/op/strategy/generic.py
@@ -1548,27 +1548,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
return strategy


# scatter
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.scatter),
wrap_topi_schedule(topi.generic.schedule_scatter),
name="scatter.generic",
)
return strategy


def wrap_compute_scatter(topi_compute):
"""Wrap scatter topi compute"""

def _compute_scatter(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]

return _compute_scatter


# scatter_elements
@override_native_generic_func("scatter_elements_strategy")
def scatter_elements_strategy(attrs, inputs, out_type, target):
@@ -1579,6 +1558,7 @@ def scatter_elements_strategy(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_elements.generic",
)
# TODO(vvchernov): implement specialized case (rank=1, reduction="update"), see cuda strategy
return strategy


14 changes: 7 additions & 7 deletions python/tvm/relay/op/strategy/rocm.py
@@ -105,23 +105,23 @@ def argsort_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_strategy.register(["rocm"])
def scatter_cuda(attrs, inputs, out_type, target):
@scatter_elements_strategy.register(["rocm"])
def scatter_elements_cuda(attrs, inputs, out_type, target):
"""scatter rocm strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter),
wrap_topi_schedule(topi.cuda.schedule_scatter),
name="scatter.rocm",
wrap_compute_scatter_elements(topi.cuda.scatter_elements),
wrap_topi_schedule(topi.cuda.schedule_extern),
name="scatter_elements.rocm",
plevel=10,
)

rank = len(inputs[0].shape)

with SpecializedCondition(rank == 1):
with SpecializedCondition(rank == 1 and attrs.reduction == "update"):
if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_via_sort),
wrap_compute_scatter_elements(topi.cuda.scatter_via_sort),
wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
name="scatter_via_sort.rocm",
plevel=9, # use the sequential version by default