Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
TVM_ATTR_FIELD(reduction).set_default("update").describe(
"Reduction mode of the scatter elements, "
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
"either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
}
};

Expand Down
3 changes: 0 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,9 +2766,6 @@ def scatter_reduce(self, inputs, input_types):
reduce = "min"
elif reduce == "amax":
reduce = "max"
else: # reduce == "mean"
# TODO(vvchernov): support mean reduction
raise NotImplementedError("Mean reduction has not been supported yet!")

return _op.scatter_elements(data, index, src, axis=dim, reduction=reduce)

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,11 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
The axis to scatter elements on. It is zero by default.

reduction : string, optional
The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"]
The reduction mode for scatter. Choise is from ["update", "add", "mul", "mean", "min", max"]
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If mul, the input data will be multiplied on the update values
If mean, the input data will be mean between the update values and the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/topi/cuda/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,11 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
The axis to scatter on. It is zero by default.

reduction : optional, string
The update mode for the algorithm, either "update", "add", "mul", "min" or "max"
The update mode for the algorithm, either "update", "add", "mul", "mean", "min" or "max"
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If mul, the input data will be multiplied on the update values
If mean, the input data will be mean between the update values and the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Expand All @@ -258,6 +259,9 @@ def add_func(dst_ptr, dst_index, update):
def mul_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] *= update

def mean_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2

def min_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update)

Expand All @@ -271,6 +275,8 @@ def max_func(dst_ptr, dst_index, update):
reduce_func = add_func
elif reduction == "mul":
reduce_func = mul_func
elif reduction == "mean":
reduce_func = mean_func
elif reduction == "min":
reduce_func = min_func
elif reduction == "max":
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/topi/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
The update mode for the algorithm, either "update", "add", "mul", "min" or "max"
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If mul, the input data will be multiplied on the update values
If mean, the input data will be mean between the update values and the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default
Expand Down Expand Up @@ -133,6 +134,9 @@ def add_func(dst_ptr, dst_index, update):
def mul_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] *= update

def mean_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2

def min_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update)

Expand All @@ -146,13 +150,15 @@ def max_func(dst_ptr, dst_index, update):
reduce_func = add_func
elif reduction == "mul":
reduce_func = mul_func
elif reduction == "mean":
reduce_func = mean_func
elif reduction == "min":
reduce_func = min_func
elif reduction == "max":
reduce_func = max_func
else:
raise NotImplementedError(
"scatter_elements reduction not in [update, add, mul, min, max]:", reduction
"scatter_elements reduction not in [update, add, mul, mean, min, max]:", reduction
)

out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
Expand Down
6 changes: 2 additions & 4 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4252,16 +4252,14 @@ def test_fn_scatter_reduce(dim, reduce):
in_src = torch.rand(2, 5) - 1

targets = ["llvm", "cuda"]
# TODO(vvchernov): support test of mean reduction and include_self=False
for reduce in ["sum", "prod", "amin", "amax"]:
for reduce in ["sum", "prod", "amin", "amax", "mean"]:
verify_trace_model(test_fn_scatter_reduce(0, reduce), [in_data, in_index, in_src], targets)

in_data = torch.rand(2, 4) - 1
in_index = torch.tensor([[2], [3]])
in_src = torch.rand(2, 1) - 1

# TODO(vvchernov): support test of mean reduction and include_self=False
for reduce in ["sum", "prod", "amin", "amax"]:
for reduce in ["sum", "prod", "amin", "amax", "mean"]:
verify_trace_model(test_fn_scatter_reduce(1, reduce), [in_data, in_index, in_src], targets)


Expand Down